"vscode:/vscode.git/clone" did not exist on "07f95503e5b3e1823a5183d5999f4cb299becd19"
Unverified Commit 67dc1197 authored by Minjie Wang's avatar Minjie Wang Committed by GitHub
Browse files

[Refactor] Use object system for all CAPIs (#716)

* WIP: using object system for graph

* c++ side refactoring done; compiled

* remove stale apis

* fix bug in DGLGraphCreate; passed test_graph.py

* fix bug in python modify; passed utest for pytorch/cpu

* fix lint

* address comments
parent b0d9e7aa
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
#include <cstdint> #include <cstdint>
#include <utility> #include <utility>
#include <tuple> #include <tuple>
#include <memory>
#include "graph_interface.h" #include "graph_interface.h"
...@@ -19,21 +20,9 @@ namespace dgl { ...@@ -19,21 +20,9 @@ namespace dgl {
class Graph; class Graph;
class GraphOp; class GraphOp;
typedef std::shared_ptr<Graph> MutableGraphPtr;
/*! /*! \brief Mutable graph based on adjacency list. */
* \brief Base dgl graph index class.
*
* DGL's graph is directed. Vertices are integers enumerated from zero.
*
* Removal of vertices/edges is not allowed. Instead, the graph can only be "cleared"
* by removing all the vertices and edges.
*
* When calling functions supporing multiple edges (e.g. AddEdges, HasEdges),
* the input edges are represented by two id arrays for source and destination
* vertex ids. In the general case, the two arrays should have the same length.
* If the length of src id array is one, it represents one-many connections.
* If the length of dst id array is one, it represents many-one connections.
*/
class Graph: public GraphInterface { class Graph: public GraphInterface {
public: public:
/*! \brief default constructor */ /*! \brief default constructor */
...@@ -305,15 +294,6 @@ class Graph: public GraphInterface { ...@@ -305,15 +294,6 @@ class Graph: public GraphInterface {
*/ */
Subgraph EdgeSubgraph(IdArray eids, bool preserve_nodes = false) const override; Subgraph EdgeSubgraph(IdArray eids, bool preserve_nodes = false) const override;
/*!
* \brief Return a new graph with all the edges reversed.
*
* The returned graph preserves the vertex and edge index in the original graph.
*
* \return the reversed graph
*/
GraphPtr Reverse() const override;
/*! /*!
* \brief Return the successor vector * \brief Return the successor vector
* \param vid The vertex id. * \param vid The vertex id.
...@@ -358,16 +338,6 @@ class Graph: public GraphInterface { ...@@ -358,16 +338,6 @@ class Graph: public GraphInterface {
return DGLIdIters(data, data + size); return DGLIdIters(data, data + size);
} }
/*!
* \brief Reset the data in the graph and move its data to the returned graph object.
* \return a raw pointer to the graph object.
*/
GraphInterface *Reset() override {
Graph* gptr = new Graph();
*gptr = std::move(*this);
return gptr;
}
/*! /*!
* \brief Get the adjacency matrix of the graph. * \brief Get the adjacency matrix of the graph.
* *
...@@ -379,6 +349,17 @@ class Graph: public GraphInterface { ...@@ -379,6 +349,17 @@ class Graph: public GraphInterface {
*/ */
std::vector<IdArray> GetAdj(bool transpose, const std::string &fmt) const override; std::vector<IdArray> GetAdj(bool transpose, const std::string &fmt) const override;
/*! \brief Create an empty graph */
static MutableGraphPtr Create(bool multigraph = false) {
return std::make_shared<Graph>(multigraph);
}
/*! \brief Create from coo */
static MutableGraphPtr CreateFromCOO(
int64_t num_nodes, IdArray src_ids, IdArray dst_ids, bool multigraph = false) {
return std::make_shared<Graph>(src_ids, dst_ids, num_nodes, multigraph);
}
protected: protected:
friend class GraphOp; friend class GraphOp;
/*! \brief Internal edge list type */ /*! \brief Internal edge list type */
......
...@@ -11,13 +11,11 @@ ...@@ -11,13 +11,11 @@
#include <utility> #include <utility>
#include <algorithm> #include <algorithm>
#include "./runtime/object.h"
#include "array.h" #include "array.h"
namespace dgl { namespace dgl {
struct Subgraph;
struct NodeFlow;
const dgl_id_t DGL_INVALID_ID = static_cast<dgl_id_t>(-1); const dgl_id_t DGL_INVALID_ID = static_cast<dgl_id_t>(-1);
/*! /*!
...@@ -51,6 +49,15 @@ class DGLIdIters { ...@@ -51,6 +49,15 @@ class DGLIdIters {
const dgl_id_t *begin_{nullptr}, *end_{nullptr}; const dgl_id_t *begin_{nullptr}, *end_{nullptr};
}; };
/* \brief structure used to represent a list of edges */
typedef struct {
/* \brief the two endpoints and the id of the edge */
IdArray src, dst, id;
} EdgeArray;
// forward declaration
struct Subgraph;
class GraphRef;
class GraphInterface; class GraphInterface;
typedef std::shared_ptr<GraphInterface> GraphPtr; typedef std::shared_ptr<GraphInterface> GraphPtr;
...@@ -58,15 +65,15 @@ typedef std::shared_ptr<GraphInterface> GraphPtr; ...@@ -58,15 +65,15 @@ typedef std::shared_ptr<GraphInterface> GraphPtr;
* \brief dgl graph index interface. * \brief dgl graph index interface.
* *
* DGL's graph is directed. Vertices are integers enumerated from zero. * DGL's graph is directed. Vertices are integers enumerated from zero.
*
* When calling functions supporing multiple edges (e.g. AddEdges, HasEdges),
* the input edges are represented by two id arrays for source and destination
* vertex ids. In the general case, the two arrays should have the same length.
* If the length of src id array is one, it represents one-many connections.
* If the length of dst id array is one, it represents many-one connections.
*/ */
class GraphInterface { class GraphInterface : public runtime::Object {
public: public:
/* \brief structure used to represent a list of edges */
typedef struct {
/* \brief the two endpoints and the id of the edge */
IdArray src, dst, id;
} EdgeArray;
virtual ~GraphInterface() = default; virtual ~GraphInterface() = default;
/*! /*!
...@@ -293,15 +300,6 @@ class GraphInterface { ...@@ -293,15 +300,6 @@ class GraphInterface {
*/ */
virtual Subgraph EdgeSubgraph(IdArray eids, bool preserve_nodes = false) const = 0; virtual Subgraph EdgeSubgraph(IdArray eids, bool preserve_nodes = false) const = 0;
/*!
* \brief Return a new graph with all the edges reversed.
*
* The returned graph preserves the vertex and edge index in the original graph.
*
* \return the reversed graph
*/
virtual GraphPtr Reverse() const = 0;
/*! /*!
* \brief Return the successor vector * \brief Return the successor vector
* \param vid The vertex id. * \param vid The vertex id.
...@@ -330,12 +328,6 @@ class GraphInterface { ...@@ -330,12 +328,6 @@ class GraphInterface {
*/ */
virtual DGLIdIters InEdgeVec(dgl_id_t vid) const = 0; virtual DGLIdIters InEdgeVec(dgl_id_t vid) const = 0;
/*!
* \brief Reset the data in the graph and move its data to the returned graph object.
* \return a raw pointer to the graph object.
*/
virtual GraphInterface *Reset() = 0;
/*! /*!
* \brief Get the adjacency matrix of the graph. * \brief Get the adjacency matrix of the graph.
* *
...@@ -353,8 +345,14 @@ class GraphInterface { ...@@ -353,8 +345,14 @@ class GraphInterface {
* \return a vector of IdArrays. * \return a vector of IdArrays.
*/ */
virtual std::vector<IdArray> GetAdj(bool transpose, const std::string &fmt) const = 0; virtual std::vector<IdArray> GetAdj(bool transpose, const std::string &fmt) const = 0;
static constexpr const char* _type_key = "graph.Graph";
DGL_DECLARE_OBJECT_TYPE_INFO(GraphInterface, runtime::Object);
}; };
// Define GraphRef
DGL_DEFINE_OBJECT_REF(GraphRef, GraphInterface);
/*! \brief Subgraph data structure */ /*! \brief Subgraph data structure */
struct Subgraph { struct Subgraph {
/*! \brief The graph. */ /*! \brief The graph. */
......
...@@ -14,6 +14,15 @@ namespace dgl { ...@@ -14,6 +14,15 @@ namespace dgl {
class GraphOp { class GraphOp {
public: public:
/*!
* \brief Return a new graph with all the edges reversed.
*
* The returned graph preserves the vertex and edge index in the original graph.
*
* \return the reversed graph
*/
static GraphPtr Reverse(GraphPtr graph);
/*! /*!
* \brief Return the line graph. * \brief Return the line graph.
* *
...@@ -25,7 +34,7 @@ class GraphOp { ...@@ -25,7 +34,7 @@ class GraphOp {
* \param backtracking Whether the backtracking edges are included or not * \param backtracking Whether the backtracking edges are included or not
* \return the line graph * \return the line graph
*/ */
static Graph LineGraph(const Graph* graph, bool backtracking); static GraphPtr LineGraph(GraphPtr graph, bool backtracking);
/*! /*!
* \brief Return a disjoint union of the input graphs. * \brief Return a disjoint union of the input graphs.
...@@ -36,10 +45,13 @@ class GraphOp { ...@@ -36,10 +45,13 @@ class GraphOp {
* they have 5, 6, 7 nodes respectively. Then node#2 of g2 will become node#7 * they have 5, 6, 7 nodes respectively. Then node#2 of g2 will become node#7
* in the result graph. Edge ids are re-assigned similarly. * in the result graph. Edge ids are re-assigned similarly.
* *
* The input list must be either ALL mutable graphs or ALL immutable graphs.
* The returned graph type is also determined by the input graph type.
*
* \param graphs A list of input graphs to be unioned. * \param graphs A list of input graphs to be unioned.
* \return the disjoint union of the graphs * \return the disjoint union of the graphs
*/ */
static Graph DisjointUnion(std::vector<const Graph*> graphs); static GraphPtr DisjointUnion(std::vector<GraphPtr> graphs);
/*! /*!
* \brief Partition the graph into several subgraphs. * \brief Partition the graph into several subgraphs.
...@@ -48,11 +60,14 @@ class GraphOp { ...@@ -48,11 +60,14 @@ class GraphOp {
* into num graphs. This requires the given number of partitions to evenly * into num graphs. This requires the given number of partitions to evenly
* divides the number of nodes in the graph. * divides the number of nodes in the graph.
* *
* If the input graph is mutable, the result graphs are mutable.
* If the input graph is immutable, the result graphs are immutable.
*
* \param graph The graph to be partitioned. * \param graph The graph to be partitioned.
* \param num The number of partitions. * \param num The number of partitions.
* \return a list of partitioned graphs * \return a list of partitioned graphs
*/ */
static std::vector<Graph> DisjointPartitionByNum(const Graph* graph, int64_t num); static std::vector<GraphPtr> DisjointPartitionByNum(GraphPtr graph, int64_t num);
/*! /*!
* \brief Partition the graph into several subgraphs. * \brief Partition the graph into several subgraphs.
...@@ -61,53 +76,14 @@ class GraphOp { ...@@ -61,53 +76,14 @@ class GraphOp {
* based on the given sizes. This requires the sum of the given sizes is equal * based on the given sizes. This requires the sum of the given sizes is equal
* to the number of nodes in the graph. * to the number of nodes in the graph.
* *
* If the input graph is mutable, the result graphs are mutable.
* If the input graph is immutable, the result graphs are immutable.
*
* \param graph The graph to be partitioned. * \param graph The graph to be partitioned.
* \param sizes The number of partitions. * \param sizes The number of partitions.
* \return a list of partitioned graphs * \return a list of partitioned graphs
*/ */
static std::vector<Graph> DisjointPartitionBySizes(const Graph* graph, IdArray sizes); static std::vector<GraphPtr> DisjointPartitionBySizes(GraphPtr graph, IdArray sizes);
/*!
* \brief Return a readonly disjoint union of the input graphs.
*
* The new readonly graph will include all the nodes/edges in the given graphs.
* Nodes/Edges will be relabled in the given sequence order by batching over CSR Graphs.
* For example, giving input [g1, g2, g3], where
* they have 5, 6, 7 nodes respectively. Then node#2 of g2 will become node#7
* in the result graph. Edge ids are re-assigned similarly.
*
* \param ImmutableGraph A list of input graphs to be unioned.
* \return the disjoint union of the ImmutableGraph
*/
static ImmutableGraph DisjointUnion(std::vector<const ImmutableGraph*> graphs);
/*!
* \brief Partition the ImmutableGraph into several immutable subgraphs.
*
* This is a reverse operation of DisjointUnion. The graph will be partitioned
* into num graphs. This requires the given number of partitions to evenly
* divides the number of nodes in the graph.
*
* \param graph The ImmutableGraph to be partitioned.
* \param num The number of partitions.
* \return a list of partitioned ImmutableGraph
*/
static std::vector<ImmutableGraph> DisjointPartitionByNum(const ImmutableGraph *graph,
int64_t num);
/*!
* \brief Partition the ImmutableGraph into several immutable subgraphs.
*
* This is a reverse operation of DisjointUnion. The graph will be partitioned
* based on the given sizes. This requires the sum of the given sizes is equal
* to the number of nodes in the graph.
*
* \param graph The ImmutableGraph to be partitioned.
* \param sizes The number of partitions.
* \return a list of partitioned ImmutableGraph
*/
static std::vector<ImmutableGraph> DisjointPartitionBySizes(const ImmutableGraph *batched_graph,
IdArray sizes);
/*! /*!
* \brief Map vids in the parent graph to the vids in the subgraph. * \brief Map vids in the parent graph to the vids in the subgraph.
...@@ -143,7 +119,7 @@ class GraphOp { ...@@ -143,7 +119,7 @@ class GraphOp {
* \param graph The input graph. * \param graph The input graph.
* \return a new immutable simple graph with no multi-edge. * \return a new immutable simple graph with no multi-edge.
*/ */
static ImmutableGraph ToSimpleGraph(const GraphInterface* graph); static GraphPtr ToSimpleGraph(GraphPtr graph);
/*! /*!
* \brief Convert the graph to a mutable bidirected graph. * \brief Convert the graph to a mutable bidirected graph.
...@@ -155,14 +131,14 @@ class GraphOp { ...@@ -155,14 +131,14 @@ class GraphOp {
* \param graph The input graph. * \param graph The input graph.
* \return a new mutable bidirected graph. * \return a new mutable bidirected graph.
*/ */
static Graph ToBidirectedMutableGraph(const GraphInterface* graph); static GraphPtr ToBidirectedMutableGraph(GraphPtr graph);
/*! /*!
* \brief Same as BidirectedMutableGraph except that the returned graph is immutable. * \brief Same as BidirectedMutableGraph except that the returned graph is immutable.
* \param graph The input graph. * \param graph The input graph.
* \return a new immutable bidirected graph. * \return a new immutable bidirected graph.
*/ */
static ImmutableGraph ToBidirectedImmutableGraph(const GraphInterface* graph); static GraphPtr ToBidirectedImmutableGraph(GraphPtr graph);
}; };
} // namespace dgl } // namespace dgl
......
...@@ -38,14 +38,6 @@ typedef std::shared_ptr<HeteroGraphInterface> HeteroGraphPtr; ...@@ -38,14 +38,6 @@ typedef std::shared_ptr<HeteroGraphInterface> HeteroGraphPtr;
*/ */
class HeteroGraphInterface { class HeteroGraphInterface {
public: public:
/* \brief structure used to represent a list of edges */
// TODO(minjie): move this data structure outside of class definition so
// it can be shared by Graph and HeteroGraph.
typedef struct {
/* \brief the two endpoints and the id of the edge */
IdArray src, dst, id;
} EdgeArray;
virtual ~HeteroGraphInterface() = default; virtual ~HeteroGraphInterface() = default;
////////////////////////// query/operations on meta graph //////////////////////// ////////////////////////// query/operations on meta graph ////////////////////////
...@@ -112,7 +104,7 @@ class HeteroGraphInterface { ...@@ -112,7 +104,7 @@ class HeteroGraphInterface {
/*! \return a 0-1 array indicating whether the given vertices are in the graph.*/ /*! \return a 0-1 array indicating whether the given vertices are in the graph.*/
virtual BoolArray HasVertices(dgl_type_t vtype, IdArray vids) const { virtual BoolArray HasVertices(dgl_type_t vtype, IdArray vids) const {
const auto len = vids->shape[0]; const auto len = vids->shape[0];
BoolArray rst = NewBoolArray(len); BoolArray rst = aten::NewBoolArray(len);
const dgl_id_t* vid_data = static_cast<dgl_id_t*>(vids->data); const dgl_id_t* vid_data = static_cast<dgl_id_t*>(vids->data);
dgl_id_t* rst_data = static_cast<dgl_id_t*>(rst->data); dgl_id_t* rst_data = static_cast<dgl_id_t*>(rst->data);
for (int64_t i = 0; i < len; ++i) { for (int64_t i = 0; i < len; ++i) {
...@@ -129,7 +121,7 @@ class HeteroGraphInterface { ...@@ -129,7 +121,7 @@ class HeteroGraphInterface {
const auto srclen = src_ids->shape[0]; const auto srclen = src_ids->shape[0];
const auto dstlen = dst_ids->shape[0]; const auto dstlen = dst_ids->shape[0];
const auto rstlen = std::max(srclen, dstlen); const auto rstlen = std::max(srclen, dstlen);
BoolArray rst = NewBoolArray(rstlen); BoolArray rst = aten::NewBoolArray(rstlen);
dgl_id_t* rst_data = static_cast<dgl_id_t*>(rst->data); dgl_id_t* rst_data = static_cast<dgl_id_t*>(rst->data);
const dgl_id_t* src_data = static_cast<dgl_id_t*>(src_ids->data); const dgl_id_t* src_data = static_cast<dgl_id_t*>(src_ids->data);
const dgl_id_t* dst_data = static_cast<dgl_id_t*>(dst_ids->data); const dgl_id_t* dst_data = static_cast<dgl_id_t*>(dst_ids->data);
...@@ -341,12 +333,6 @@ class HeteroGraphInterface { ...@@ -341,12 +333,6 @@ class HeteroGraphInterface {
*/ */
virtual DGLIdIters InEdgeVec(dgl_type_t etype, dgl_id_t vid) const = 0; virtual DGLIdIters InEdgeVec(dgl_type_t etype, dgl_id_t vid) const = 0;
/*!
* \brief Reset the data in the graph and move its data to the returned graph object.
* \return a raw pointer to the graph object.
*/
virtual HeteroGraphInterface *Reset() = 0;
/*! /*!
* \brief Get the adjacency matrix of the graph. * \brief Get the adjacency matrix of the graph.
* *
......
...@@ -23,6 +23,9 @@ class COO; ...@@ -23,6 +23,9 @@ class COO;
typedef std::shared_ptr<CSR> CSRPtr; typedef std::shared_ptr<CSR> CSRPtr;
typedef std::shared_ptr<COO> COOPtr; typedef std::shared_ptr<COO> COOPtr;
class ImmutableGraph;
typedef std::shared_ptr<ImmutableGraph> ImmutableGraphPtr;
/*! /*!
* \brief Graph class stored using CSR structure. * \brief Graph class stored using CSR structure.
*/ */
...@@ -167,10 +170,6 @@ class CSR : public GraphInterface { ...@@ -167,10 +170,6 @@ class CSR : public GraphInterface {
return {}; return {};
} }
GraphPtr Reverse() const override {
return Transpose();
}
DGLIdIters SuccVec(dgl_id_t vid) const override; DGLIdIters SuccVec(dgl_id_t vid) const override;
DGLIdIters OutEdgeVec(dgl_id_t vid) const override; DGLIdIters OutEdgeVec(dgl_id_t vid) const override;
...@@ -187,12 +186,6 @@ class CSR : public GraphInterface { ...@@ -187,12 +186,6 @@ class CSR : public GraphInterface {
return DGLIdIters(nullptr, nullptr); return DGLIdIters(nullptr, nullptr);
} }
GraphInterface *Reset() override {
CSR* gptr = new CSR();
*gptr = std::move(*this);
return gptr;
}
std::vector<IdArray> GetAdj(bool transpose, const std::string &fmt) const override { std::vector<IdArray> GetAdj(bool transpose, const std::string &fmt) const override {
CHECK(!transpose && fmt == "csr") << "Not valid adj format request."; CHECK(!transpose && fmt == "csr") << "Not valid adj format request.";
return {adj_.indptr, adj_.indices, adj_.data}; return {adj_.indptr, adj_.indices, adj_.data};
...@@ -256,7 +249,7 @@ class CSR : public GraphInterface { ...@@ -256,7 +249,7 @@ class CSR : public GraphInterface {
aten::CSRMatrix adj_; aten::CSRMatrix adj_;
// whether the graph is a multi-graph // whether the graph is a multi-graph
LazyObject<bool> is_multigraph_; Lazy<bool> is_multigraph_;
// The name of the shared memory to store data. // The name of the shared memory to store data.
// If it's empty, data isn't stored in shared memory. // If it's empty, data isn't stored in shared memory.
...@@ -416,10 +409,6 @@ class COO : public GraphInterface { ...@@ -416,10 +409,6 @@ class COO : public GraphInterface {
Subgraph EdgeSubgraph(IdArray eids, bool preserve_nodes = false) const override; Subgraph EdgeSubgraph(IdArray eids, bool preserve_nodes = false) const override;
GraphPtr Reverse() const override {
return Transpose();
}
DGLIdIters SuccVec(dgl_id_t vid) const override { DGLIdIters SuccVec(dgl_id_t vid) const override {
LOG(FATAL) << "COO graph does not support efficient SuccVec." LOG(FATAL) << "COO graph does not support efficient SuccVec."
<< " Please use CSR graph or AdjList graph instead."; << " Please use CSR graph or AdjList graph instead.";
...@@ -444,12 +433,6 @@ class COO : public GraphInterface { ...@@ -444,12 +433,6 @@ class COO : public GraphInterface {
return DGLIdIters(nullptr, nullptr); return DGLIdIters(nullptr, nullptr);
} }
GraphInterface *Reset() override {
COO* gptr = new COO();
*gptr = std::move(*this);
return gptr;
}
std::vector<IdArray> GetAdj(bool transpose, const std::string &fmt) const override { std::vector<IdArray> GetAdj(bool transpose, const std::string &fmt) const override {
CHECK(fmt == "coo") << "Not valid adj format request."; CHECK(fmt == "coo") << "Not valid adj format request.";
if (transpose) { if (transpose) {
...@@ -517,7 +500,7 @@ class COO : public GraphInterface { ...@@ -517,7 +500,7 @@ class COO : public GraphInterface {
aten::COOMatrix adj_; aten::COOMatrix adj_;
/*! \brief whether the graph is a multi-graph */ /*! \brief whether the graph is a multi-graph */
LazyObject<bool> is_multigraph_; Lazy<bool> is_multigraph_;
}; };
/*! /*!
...@@ -840,21 +823,6 @@ class ImmutableGraph: public GraphInterface { ...@@ -840,21 +823,6 @@ class ImmutableGraph: public GraphInterface {
*/ */
Subgraph EdgeSubgraph(IdArray eids, bool preserve_nodes = false) const override; Subgraph EdgeSubgraph(IdArray eids, bool preserve_nodes = false) const override;
/*!
* \brief Return a new graph with all the edges reversed.
*
* The returned graph preserves the vertex and edge index in the original graph.
*
* \return the reversed graph
*/
GraphPtr Reverse() const override {
if (coo_) {
return GraphPtr(new ImmutableGraph(out_csr_, in_csr_, coo_->Transpose()));
} else {
return GraphPtr(new ImmutableGraph(out_csr_, in_csr_));
}
}
/*! /*!
* \brief Return the successor vector * \brief Return the successor vector
* \param vid The vertex id. * \param vid The vertex id.
...@@ -891,16 +859,6 @@ class ImmutableGraph: public GraphInterface { ...@@ -891,16 +859,6 @@ class ImmutableGraph: public GraphInterface {
return GetInCSR()->OutEdgeVec(vid); return GetInCSR()->OutEdgeVec(vid);
} }
/*!
* \brief Reset the data in the graph and move its data to the returned graph object.
* \return a raw pointer to the graph object.
*/
GraphInterface *Reset() override {
ImmutableGraph* gptr = new ImmutableGraph();
*gptr = std::move(*this);
return gptr;
}
/*! /*!
* \brief Get the adjacency matrix of the graph. * \brief Get the adjacency matrix of the graph.
* *
...@@ -921,6 +879,35 @@ class ImmutableGraph: public GraphInterface { ...@@ -921,6 +879,35 @@ class ImmutableGraph: public GraphInterface {
/* !\brief Return coo. If not exist, create from csr.*/ /* !\brief Return coo. If not exist, create from csr.*/
COOPtr GetCOO() const; COOPtr GetCOO() const;
/*! \brief Create an immutable graph from CSR. */
static ImmutableGraphPtr CreateFromCSR(
IdArray indptr, IdArray indices, IdArray edge_ids, const std::string &edge_dir);
static ImmutableGraphPtr CreateFromCSR(
IdArray indptr, IdArray indices, IdArray edge_ids,
bool multigraph, const std::string &edge_dir);
static ImmutableGraphPtr CreateFromCSR(
IdArray indptr, IdArray indices, IdArray edge_ids,
const std::string &edge_dir, const std::string &shared_mem_name);
static ImmutableGraphPtr CreateFromCSR(
IdArray indptr, IdArray indices, IdArray edge_ids,
bool multigraph, const std::string &edge_dir,
const std::string &shared_mem_name);
static ImmutableGraphPtr CreateFromCSR(
const std::string &shared_mem_name, size_t num_vertices,
size_t num_edges, bool multigraph,
const std::string &edge_dir);
/*! \brief Create an immutable graph from COO. */
static ImmutableGraphPtr CreateFromCOO(
int64_t num_vertices, IdArray src, IdArray dst);
static ImmutableGraphPtr CreateFromCOO(
int64_t num_vertices, IdArray src, IdArray dst, bool multigraph);
/*! /*!
* \brief Convert the given graph to an immutable graph. * \brief Convert the given graph to an immutable graph.
* *
...@@ -930,14 +917,14 @@ class ImmutableGraph: public GraphInterface { ...@@ -930,14 +917,14 @@ class ImmutableGraph: public GraphInterface {
* \param graph The input graph. * \param graph The input graph.
* \return an immutable graph object. * \return an immutable graph object.
*/ */
static ImmutableGraph ToImmutable(const GraphInterface* graph); static ImmutableGraphPtr ToImmutable(GraphPtr graph);
/*! /*!
* \brief Copy the data to another context. * \brief Copy the data to another context.
* \param ctx The target context. * \param ctx The target context.
* \return The graph under another context. * \return The graph under another context.
*/ */
ImmutableGraph CopyTo(const DLContext& ctx) const; static ImmutableGraphPtr CopyTo(ImmutableGraphPtr g, const DLContext& ctx);
/*! /*!
* \brief Copy data to shared memory. * \brief Copy data to shared memory.
...@@ -945,85 +932,24 @@ class ImmutableGraph: public GraphInterface { ...@@ -945,85 +932,24 @@ class ImmutableGraph: public GraphInterface {
* \param name The name of the shared memory. * \param name The name of the shared memory.
* \return The graph in the shared memory * \return The graph in the shared memory
*/ */
ImmutableGraph CopyToSharedMem(const std::string &edge_dir, const std::string &name) const; static ImmutableGraphPtr CopyToSharedMem(
ImmutableGraphPtr g, const std::string &edge_dir, const std::string &name);
/*! /*!
* \brief Convert the graph to use the given number of bits for storage. * \brief Convert the graph to use the given number of bits for storage.
* \param bits The new number of integer bits (32 or 64). * \param bits The new number of integer bits (32 or 64).
* \return The graph with new bit size storage. * \return The graph with new bit size storage.
*/ */
ImmutableGraph AsNumBits(uint8_t bits) const; static ImmutableGraphPtr AsNumBits(ImmutableGraphPtr g, uint8_t bits);
/*! \brief Create an immutable graph from CSR. */ /*!
static ImmutableGraph CreateFromCSR(IdArray indptr, IdArray indices, IdArray edge_ids, * \brief Return a new graph with all the edges reversed.
const std::string &edge_dir) { *
CSRPtr csr(new CSR(indptr, indices, edge_ids)); * The returned graph preserves the vertex and edge index in the original graph.
if (edge_dir == "in") { *
return ImmutableGraph(csr, nullptr); * \return the reversed graph
} else if (edge_dir == "out") { */
return ImmutableGraph(nullptr, csr); ImmutableGraphPtr Reverse() const;
} else {
LOG(FATAL) << "Unknown edge direction: " << edge_dir;
return ImmutableGraph();
}
}
static ImmutableGraph CreateFromCSR(IdArray indptr, IdArray indices, IdArray edge_ids,
bool multigraph, const std::string &edge_dir) {
CSRPtr csr(new CSR(indptr, indices, edge_ids, multigraph));
if (edge_dir == "in") {
return ImmutableGraph(csr, nullptr);
} else if (edge_dir == "out") {
return ImmutableGraph(nullptr, csr);
} else {
LOG(FATAL) << "Unknown edge direction: " << edge_dir;
return ImmutableGraph();
}
}
static ImmutableGraph CreateFromCSR(IdArray indptr, IdArray indices, IdArray edge_ids,
const std::string &edge_dir,
const std::string &shared_mem_name) {
CSRPtr csr(new CSR(indptr, indices, edge_ids, GetSharedMemName(shared_mem_name, edge_dir)));
if (edge_dir == "in") {
return ImmutableGraph(csr, nullptr, shared_mem_name);
} else if (edge_dir == "out") {
return ImmutableGraph(nullptr, csr, shared_mem_name);
} else {
LOG(FATAL) << "Unknown edge direction: " << edge_dir;
return ImmutableGraph();
}
}
static ImmutableGraph CreateFromCSR(IdArray indptr, IdArray indices, IdArray edge_ids,
bool multigraph, const std::string &edge_dir,
const std::string &shared_mem_name) {
CSRPtr csr(new CSR(indptr, indices, edge_ids, multigraph,
GetSharedMemName(shared_mem_name, edge_dir)));
if (edge_dir == "in") {
return ImmutableGraph(csr, nullptr, shared_mem_name);
} else if (edge_dir == "out") {
return ImmutableGraph(nullptr, csr, shared_mem_name);
} else {
LOG(FATAL) << "Unknown edge direction: " << edge_dir;
return ImmutableGraph();
}
}
static ImmutableGraph CreateFromCSR(const std::string &shared_mem_name, size_t num_vertices,
size_t num_edges, bool multigraph,
const std::string &edge_dir) {
CSRPtr csr(new CSR(GetSharedMemName(shared_mem_name, edge_dir), num_vertices, num_edges,
multigraph));
if (edge_dir == "in") {
return ImmutableGraph(csr, nullptr, shared_mem_name);
} else if (edge_dir == "out") {
return ImmutableGraph(nullptr, csr, shared_mem_name);
} else {
LOG(FATAL) << "Unknown edge direction: " << edge_dir;
return ImmutableGraph();
}
}
protected: protected:
/* !\brief internal default constructor */ /* !\brief internal default constructor */
...@@ -1041,10 +967,6 @@ class ImmutableGraph: public GraphInterface { ...@@ -1041,10 +967,6 @@ class ImmutableGraph: public GraphInterface {
this->shared_mem_name_ = shared_mem_name; this->shared_mem_name_ = shared_mem_name;
} }
static std::string GetSharedMemName(const std::string &name, const std::string &edge_dir) {
return name + "_" + edge_dir;
}
/* !\brief return pointer to any available graph structure */ /* !\brief return pointer to any available graph structure */
GraphPtr AnyGraph() const { GraphPtr AnyGraph() const {
if (in_csr_) { if (in_csr_) {
......
...@@ -17,16 +17,16 @@ namespace dgl { ...@@ -17,16 +17,16 @@ namespace dgl {
* The object is currently not threaad safe. * The object is currently not threaad safe.
*/ */
template <typename T> template <typename T>
class LazyObject { class Lazy {
public: public:
/*!\brief default constructor to construct a lazy object */ /*!\brief default constructor to construct a lazy object */
LazyObject() {} Lazy() {}
/*!\brief constructor to construct an object with given value (non-lazy case) */ /*!\brief constructor to construct an object with given value (non-lazy case) */
explicit LazyObject(const T& val): ptr_(new T(val)) {} explicit Lazy(const T& val): ptr_(new T(val)) {}
/*!\brief destructor */ /*!\brief destructor */
~LazyObject() = default; ~Lazy() = default;
/*! /*!
* \brief Get the value of this object. If the object has not been instantiated, * \brief Get the value of this object. If the object has not been instantiated,
......
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
#include <vector> #include <vector>
#include <string> #include <string>
#include "./runtime/object.h"
#include "graph_interface.h" #include "graph_interface.h"
namespace dgl { namespace dgl {
...@@ -23,7 +24,7 @@ class ImmutableGraph; ...@@ -23,7 +24,7 @@ class ImmutableGraph;
* in a more compact format. We store extra information, * in a more compact format. We store extra information,
* such as the node and edge mapping from the NodeFlow graph to the parent graph. * such as the node and edge mapping from the NodeFlow graph to the parent graph.
*/ */
struct NodeFlow { struct NodeFlowObject : public runtime::Object {
/*! \brief The graph. */ /*! \brief The graph. */
GraphPtr graph; GraphPtr graph;
/*! /*!
...@@ -42,6 +43,20 @@ struct NodeFlow { ...@@ -42,6 +43,20 @@ struct NodeFlow {
* \brief The edge mapping from the NodeFlow graph to the parent graph. * \brief The edge mapping from the NodeFlow graph to the parent graph.
*/ */
IdArray edge_mapping; IdArray edge_mapping;
static constexpr const char* _type_key = "graph.NodeFlow";
DGL_DECLARE_OBJECT_TYPE_INFO(NodeFlowObject, runtime::Object);
};
// Define NodeFlow as the reference class of NodeFlowObject
class NodeFlow : public runtime::ObjectRef {
public:
DGL_DEFINE_OBJECT_REF_METHODS(NodeFlow, runtime::ObjectRef, NodeFlowObject);
/*! \brief create a new nodeflow reference */
static NodeFlow Create() {
return NodeFlow(std::make_shared<NodeFlowObject>());
}
}; };
/*! /*!
......
...@@ -314,6 +314,10 @@ class List : public ObjectRef { ...@@ -314,6 +314,10 @@ class List : public ObjectRef {
inline bool empty() const { inline bool empty() const {
return size() == 0; return size() == 0;
} }
/*! \brief Copy the content to a vector */
inline std::vector<T> ToVector() const {
return std::vector<T>(begin(), end());
}
/*! \brief specify container obj */ /*! \brief specify container obj */
using ContainerType = ListObject; using ContainerType = ListObject;
......
...@@ -192,10 +192,10 @@ class ObjectRef { ...@@ -192,10 +192,10 @@ class ObjectRef {
* For example: * For example:
* *
* // This class is an abstract class and cannot create instances * // This class is an abstract class and cannot create instances
* class SomeBaseClass : public Node { * class SomeBaseClass : public Object {
* public: * public:
* static constexpr const char* _type_key = "some_base"; * static constexpr const char* _type_key = "some_base";
* DGL_DECLARE_BASE_OBJECT_INFO(SomeBaseClass, Node); * DGL_DECLARE_BASE_OBJECT_INFO(SomeBaseClass, Object);
* }; * };
* *
* // Child class that allows instantiation * // Child class that allows instantiation
...@@ -219,6 +219,29 @@ class ObjectRef { ...@@ -219,6 +219,29 @@ class ObjectRef {
return Parent::_DerivedFrom(tid); \ return Parent::_DerivedFrom(tid); \
} }
/*! \brief Macro to generate common object reference class method definition */
#define DGL_DEFINE_OBJECT_REF_METHODS(TypeName, BaseTypeName, ObjectName) \
TypeName() {} \
explicit TypeName(std::shared_ptr<runtime::Object> obj): BaseTypeName(obj) {} \
const ObjectName* operator->() const { \
return static_cast<const ObjectName*>(obj_.get()); \
} \
ObjectName* operator->() { \
return static_cast<ObjectName*>(obj_.get()); \
} \
std::shared_ptr<ObjectName> sptr() const { \
return CHECK_NOTNULL(std::dynamic_pointer_cast<ObjectName>(obj_)); \
} \
operator bool() const { return this->defined(); } \
using ContainerType = ObjectName;
/*! \brief Macro to generate object reference class definition */
#define DGL_DEFINE_OBJECT_REF(TypeName, ObjectName) \
class TypeName : public ::dgl::runtime::ObjectRef { \
public: \
DGL_DEFINE_OBJECT_REF_METHODS(TypeName, ::dgl::runtime::ObjectRef, ObjectName); \
};
// implementations of inline functions after this // implementations of inline functions after this
template<typename T> template<typename T>
inline bool Object::is_type() const { inline bool Object::is_type() const {
......
...@@ -61,11 +61,6 @@ class ObjectBase(object): ...@@ -61,11 +61,6 @@ class ObjectBase(object):
"'%s' object has no attribute '%s'" % (str(type(self)), name)) "'%s' object has no attribute '%s'" % (str(type(self)), name))
return RETURN_SWITCH[ret_type_code.value](ret_val) return RETURN_SWITCH[ret_type_code.value](ret_val)
def __setattr__(self, name, value):
if name != 'handle':
raise AttributeError('Set attribute is not allowed for DGL object.')
object.__setattr__(self, name, value)
def __init_handle_by_constructor__(self, fconstructor, *args): def __init_handle_by_constructor__(self, fconstructor, *args):
"""Initialize the handle by calling constructor function. """Initialize the handle by calling constructor function.
......
...@@ -35,7 +35,7 @@ def random_walk(g, seeds, num_traces, num_hops): ...@@ -35,7 +35,7 @@ def random_walk(g, seeds, num_traces, num_hops):
if len(seeds) == 0: if len(seeds) == 0:
return utils.toindex([]).tousertensor() return utils.toindex([]).tousertensor()
seeds = utils.toindex(seeds).todgltensor() seeds = utils.toindex(seeds).todgltensor()
traces = _CAPI_DGLRandomWalk(g._graph._handle, traces = _CAPI_DGLRandomWalk(g._graph,
seeds, int(num_traces), int(num_hops)) seeds, int(num_traces), int(num_hops))
return F.zerocopy_from_dlpack(traces.to_dlpack()) return F.zerocopy_from_dlpack(traces.to_dlpack())
...@@ -109,7 +109,7 @@ def random_walk_with_restart( ...@@ -109,7 +109,7 @@ def random_walk_with_restart(
return [] return []
seeds = utils.toindex(seeds).todgltensor() seeds = utils.toindex(seeds).todgltensor()
traces = _CAPI_DGLRandomWalkWithRestart( traces = _CAPI_DGLRandomWalkWithRestart(
g._graph._handle, seeds, restart_prob, int(max_nodes_per_seed), g._graph, seeds, restart_prob, int(max_nodes_per_seed),
int(max_visit_counts), int(max_frequent_visited_nodes)) int(max_visit_counts), int(max_frequent_visited_nodes))
return _split_traces(traces) return _split_traces(traces)
...@@ -161,7 +161,7 @@ def bipartite_single_sided_random_walk_with_restart( ...@@ -161,7 +161,7 @@ def bipartite_single_sided_random_walk_with_restart(
return [] return []
seeds = utils.toindex(seeds).todgltensor() seeds = utils.toindex(seeds).todgltensor()
traces = _CAPI_DGLBipartiteSingleSidedRandomWalkWithRestart( traces = _CAPI_DGLBipartiteSingleSidedRandomWalkWithRestart(
g._graph._handle, seeds, restart_prob, int(max_nodes_per_seed), g._graph, seeds, restart_prob, int(max_nodes_per_seed),
int(max_visit_counts), int(max_frequent_visited_nodes)) int(max_visit_counts), int(max_frequent_visited_nodes))
return _split_traces(traces) return _split_traces(traces)
......
...@@ -10,7 +10,6 @@ from ..._ffi.function import _init_api ...@@ -10,7 +10,6 @@ from ..._ffi.function import _init_api
from ... import utils from ... import utils
from ...nodeflow import NodeFlow from ...nodeflow import NodeFlow
from ... import backend as F from ... import backend as F
from ...utils import unwrap_to_ptr_list
try: try:
import Queue as queue import Queue as queue
...@@ -310,8 +309,8 @@ class NeighborSampler(NodeFlowSampler): ...@@ -310,8 +309,8 @@ class NeighborSampler(NodeFlowSampler):
self._neighbor_type = neighbor_type self._neighbor_type = neighbor_type
def fetch(self, current_nodeflow_index): def fetch(self, current_nodeflow_index):
handles = unwrap_to_ptr_list(_CAPI_UniformSampling( nfobjs = _CAPI_UniformSampling(
self.g.c_handle, self.g._graph,
self.seed_nodes.todgltensor(), self.seed_nodes.todgltensor(),
current_nodeflow_index, # start batch id current_nodeflow_index, # start batch id
self.batch_size, # batch size self.batch_size, # batch size
...@@ -319,8 +318,8 @@ class NeighborSampler(NodeFlowSampler): ...@@ -319,8 +318,8 @@ class NeighborSampler(NodeFlowSampler):
self._expand_factor, self._expand_factor,
self._num_hops, self._num_hops,
self._neighbor_type, self._neighbor_type,
self._add_self_loop)) self._add_self_loop)
nflows = [NodeFlow(self.g, hdl) for hdl in handles] nflows = [NodeFlow(self.g, obj) for obj in nfobjs]
return nflows return nflows
...@@ -395,15 +394,15 @@ class LayerSampler(NodeFlowSampler): ...@@ -395,15 +394,15 @@ class LayerSampler(NodeFlowSampler):
self._layer_sizes = utils.toindex(layer_sizes) self._layer_sizes = utils.toindex(layer_sizes)
def fetch(self, current_nodeflow_index): def fetch(self, current_nodeflow_index):
handles = unwrap_to_ptr_list(_CAPI_LayerSampling( nfobjs = _CAPI_LayerSampling(
self.g.c_handle, self.g._graph,
self.seed_nodes.todgltensor(), self.seed_nodes.todgltensor(),
current_nodeflow_index, # start batch id current_nodeflow_index, # start batch id
self.batch_size, # batch size self.batch_size, # batch size
self._num_workers, # num batches self._num_workers, # num batches
self._layer_sizes.todgltensor(), self._layer_sizes.todgltensor(),
self._neighbor_type)) self._neighbor_type)
nflows = [NodeFlow(self.g, hdl) for hdl in handles] nflows = [NodeFlow(self.g, obj) for obj in nfobjs]
return nflows return nflows
def create_full_nodeflow(g, num_layers, add_self_loop=False): def create_full_nodeflow(g, num_layers, add_self_loop=False):
......
...@@ -38,11 +38,6 @@ class DGLBaseGraph(object): ...@@ -38,11 +38,6 @@ class DGLBaseGraph(object):
def __init__(self, graph): def __init__(self, graph):
self._graph = graph self._graph = graph
@property
def c_handle(self):
"""The C handle for the graph."""
return self._graph._handle
def number_of_nodes(self): def number_of_nodes(self):
"""Return the number of nodes in the graph. """Return the number of nodes in the graph.
......
"""Module for graph index class definition.""" """Module for graph index class definition."""
from __future__ import absolute_import from __future__ import absolute_import
import ctypes
import numpy as np import numpy as np
import networkx as nx import networkx as nx
import scipy import scipy
from ._ffi.base import c_array from ._ffi.object import register_object, ObjectBase
from ._ffi.function import _init_api from ._ffi.function import _init_api
from .base import DGLError from .base import DGLError
from . import backend as F from . import backend as F
from . import utils from . import utils
GraphIndexHandle = ctypes.c_void_p
class BoolFlag(object): class BoolFlag(object):
"""Bool flag with unknown value""" """Bool flag with unknown value"""
BOOL_UNKNOWN = -1 BOOL_UNKNOWN = -1
BOOL_FALSE = 0 BOOL_FALSE = 0
BOOL_TRUE = 1 BOOL_TRUE = 1
class GraphIndex(object): @register_object('graph.Graph')
class GraphIndex(ObjectBase):
"""Graph index object. """Graph index object.
Parameters Note
---------- ----
handle : GraphIndexHandle Do not create GraphIndex directly, you can create graph index object using
Handler following functions:
- `dgl.graph_index.from_edge_list`
- `dgl.graph_index.from_scipy_sparse_matrix`
- `dgl.graph_index.from_networkx`
- `dgl.graph_index.from_shared_mem_csr_matrix`
- `dgl.graph_index.from_csr`
- `dgl.graph_index.from_coo`
""" """
def __init__(self, handle): def __new__(cls):
self._handle = handle obj = ObjectBase.__new__(cls)
self._multigraph = None # python-side cache of the flag obj._multigraph = None # python-side cache of the flag
self._readonly = None # python-side cache of the flag obj._readonly = None # python-side cache of the flag
self._cache = {} obj._cache = {}
return obj
def __del__(self):
"""Free this graph index object."""
if hasattr(self, '_handle'):
_CAPI_DGLGraphFree(self._handle)
def __getstate__(self): def __getstate__(self):
src, dst, _ = self.edges() src, dst, _ = self.edges()
...@@ -59,18 +60,14 @@ class GraphIndex(object): ...@@ -59,18 +60,14 @@ class GraphIndex(object):
self._readonly = readonly self._readonly = readonly
if multigraph is None: if multigraph is None:
multigraph = BoolFlag.BOOL_UNKNOWN multigraph = BoolFlag.BOOL_UNKNOWN
self._handle = _CAPI_DGLGraphCreate( self.__init_handle_by_constructor__(
_CAPI_DGLGraphCreate,
src.todgltensor(), src.todgltensor(),
dst.todgltensor(), dst.todgltensor(),
int(multigraph), int(multigraph),
int(num_nodes), int(num_nodes),
readonly) readonly)
@property
def handle(self):
"""Get the CAPI handle."""
return self._handle
def add_nodes(self, num): def add_nodes(self, num):
"""Add nodes. """Add nodes.
...@@ -79,7 +76,7 @@ class GraphIndex(object): ...@@ -79,7 +76,7 @@ class GraphIndex(object):
num : int num : int
Number of nodes to be added. Number of nodes to be added.
""" """
_CAPI_DGLGraphAddVertices(self._handle, int(num)) _CAPI_DGLGraphAddVertices(self, int(num))
self.clear_cache() self.clear_cache()
def add_edge(self, u, v): def add_edge(self, u, v):
...@@ -92,7 +89,7 @@ class GraphIndex(object): ...@@ -92,7 +89,7 @@ class GraphIndex(object):
v : int v : int
The dst node. The dst node.
""" """
_CAPI_DGLGraphAddEdge(self._handle, u, v) _CAPI_DGLGraphAddEdge(self, u, v)
self.clear_cache() self.clear_cache()
def add_edges(self, u, v): def add_edges(self, u, v):
...@@ -107,12 +104,12 @@ class GraphIndex(object): ...@@ -107,12 +104,12 @@ class GraphIndex(object):
""" """
u_array = u.todgltensor() u_array = u.todgltensor()
v_array = v.todgltensor() v_array = v.todgltensor()
_CAPI_DGLGraphAddEdges(self._handle, u_array, v_array) _CAPI_DGLGraphAddEdges(self, u_array, v_array)
self.clear_cache() self.clear_cache()
def clear(self): def clear(self):
"""Clear the graph.""" """Clear the graph."""
_CAPI_DGLGraphClear(self._handle) _CAPI_DGLGraphClear(self)
self.clear_cache() self.clear_cache()
def clear_cache(self): def clear_cache(self):
...@@ -128,7 +125,7 @@ class GraphIndex(object): ...@@ -128,7 +125,7 @@ class GraphIndex(object):
True if it is a multigraph, False otherwise. True if it is a multigraph, False otherwise.
""" """
if self._multigraph is None: if self._multigraph is None:
self._multigraph = bool(_CAPI_DGLGraphIsMultigraph(self._handle)) self._multigraph = bool(_CAPI_DGLGraphIsMultigraph(self))
return self._multigraph return self._multigraph
def is_readonly(self): def is_readonly(self):
...@@ -140,7 +137,7 @@ class GraphIndex(object): ...@@ -140,7 +137,7 @@ class GraphIndex(object):
True if it is a read-only graph, False otherwise. True if it is a read-only graph, False otherwise.
""" """
if self._readonly is None: if self._readonly is None:
self._readonly = bool(_CAPI_DGLGraphIsReadonly(self._handle)) self._readonly = bool(_CAPI_DGLGraphIsReadonly(self))
return self._readonly return self._readonly
def readonly(self, readonly_state=True): def readonly(self, readonly_state=True):
...@@ -165,7 +162,7 @@ class GraphIndex(object): ...@@ -165,7 +162,7 @@ class GraphIndex(object):
int int
The number of nodes The number of nodes
""" """
return _CAPI_DGLGraphNumVertices(self._handle) return _CAPI_DGLGraphNumVertices(self)
def number_of_edges(self): def number_of_edges(self):
"""Return the number of edges. """Return the number of edges.
...@@ -175,7 +172,7 @@ class GraphIndex(object): ...@@ -175,7 +172,7 @@ class GraphIndex(object):
int int
The number of edges The number of edges
""" """
return _CAPI_DGLGraphNumEdges(self._handle) return _CAPI_DGLGraphNumEdges(self)
def has_node(self, vid): def has_node(self, vid):
"""Return true if the node exists. """Return true if the node exists.
...@@ -190,7 +187,7 @@ class GraphIndex(object): ...@@ -190,7 +187,7 @@ class GraphIndex(object):
bool bool
True if the node exists, False otherwise. True if the node exists, False otherwise.
""" """
return bool(_CAPI_DGLGraphHasVertex(self._handle, int(vid))) return bool(_CAPI_DGLGraphHasVertex(self, int(vid)))
def has_nodes(self, vids): def has_nodes(self, vids):
"""Return true if the nodes exist. """Return true if the nodes exist.
...@@ -206,7 +203,7 @@ class GraphIndex(object): ...@@ -206,7 +203,7 @@ class GraphIndex(object):
0-1 array indicating existence 0-1 array indicating existence
""" """
vid_array = vids.todgltensor() vid_array = vids.todgltensor()
return utils.toindex(_CAPI_DGLGraphHasVertices(self._handle, vid_array)) return utils.toindex(_CAPI_DGLGraphHasVertices(self, vid_array))
def has_edge_between(self, u, v): def has_edge_between(self, u, v):
"""Return true if the edge exists. """Return true if the edge exists.
...@@ -223,7 +220,7 @@ class GraphIndex(object): ...@@ -223,7 +220,7 @@ class GraphIndex(object):
bool bool
True if the edge exists, False otherwise True if the edge exists, False otherwise
""" """
return bool(_CAPI_DGLGraphHasEdgeBetween(self._handle, int(u), int(v))) return bool(_CAPI_DGLGraphHasEdgeBetween(self, int(u), int(v)))
def has_edges_between(self, u, v): def has_edges_between(self, u, v):
"""Return true if the edge exists. """Return true if the edge exists.
...@@ -242,7 +239,7 @@ class GraphIndex(object): ...@@ -242,7 +239,7 @@ class GraphIndex(object):
""" """
u_array = u.todgltensor() u_array = u.todgltensor()
v_array = v.todgltensor() v_array = v.todgltensor()
return utils.toindex(_CAPI_DGLGraphHasEdgesBetween(self._handle, u_array, v_array)) return utils.toindex(_CAPI_DGLGraphHasEdgesBetween(self, u_array, v_array))
def predecessors(self, v, radius=1): def predecessors(self, v, radius=1):
"""Return the predecessors of the node. """Return the predecessors of the node.
...@@ -260,7 +257,7 @@ class GraphIndex(object): ...@@ -260,7 +257,7 @@ class GraphIndex(object):
Array of predecessors Array of predecessors
""" """
return utils.toindex(_CAPI_DGLGraphPredecessors( return utils.toindex(_CAPI_DGLGraphPredecessors(
self._handle, int(v), int(radius))) self, int(v), int(radius)))
def successors(self, v, radius=1): def successors(self, v, radius=1):
"""Return the successors of the node. """Return the successors of the node.
...@@ -278,7 +275,7 @@ class GraphIndex(object): ...@@ -278,7 +275,7 @@ class GraphIndex(object):
Array of successors Array of successors
""" """
return utils.toindex(_CAPI_DGLGraphSuccessors( return utils.toindex(_CAPI_DGLGraphSuccessors(
self._handle, int(v), int(radius))) self, int(v), int(radius)))
def edge_id(self, u, v): def edge_id(self, u, v):
"""Return the id array of all edges between u and v. """Return the id array of all edges between u and v.
...@@ -295,7 +292,7 @@ class GraphIndex(object): ...@@ -295,7 +292,7 @@ class GraphIndex(object):
utils.Index utils.Index
The edge id array. The edge id array.
""" """
return utils.toindex(_CAPI_DGLGraphEdgeId(self._handle, int(u), int(v))) return utils.toindex(_CAPI_DGLGraphEdgeId(self, int(u), int(v)))
def edge_ids(self, u, v): def edge_ids(self, u, v):
"""Return a triplet of arrays that contains the edge IDs. """Return a triplet of arrays that contains the edge IDs.
...@@ -318,7 +315,7 @@ class GraphIndex(object): ...@@ -318,7 +315,7 @@ class GraphIndex(object):
""" """
u_array = u.todgltensor() u_array = u.todgltensor()
v_array = v.todgltensor() v_array = v.todgltensor()
edge_array = _CAPI_DGLGraphEdgeIds(self._handle, u_array, v_array) edge_array = _CAPI_DGLGraphEdgeIds(self, u_array, v_array)
src = utils.toindex(edge_array(0)) src = utils.toindex(edge_array(0))
dst = utils.toindex(edge_array(1)) dst = utils.toindex(edge_array(1))
...@@ -344,7 +341,7 @@ class GraphIndex(object): ...@@ -344,7 +341,7 @@ class GraphIndex(object):
The edge ids. The edge ids.
""" """
eid_array = eid.todgltensor() eid_array = eid.todgltensor()
edge_array = _CAPI_DGLGraphFindEdges(self._handle, eid_array) edge_array = _CAPI_DGLGraphFindEdges(self, eid_array)
src = utils.toindex(edge_array(0)) src = utils.toindex(edge_array(0))
dst = utils.toindex(edge_array(1)) dst = utils.toindex(edge_array(1))
...@@ -370,10 +367,10 @@ class GraphIndex(object): ...@@ -370,10 +367,10 @@ class GraphIndex(object):
The edge ids. The edge ids.
""" """
if len(v) == 1: if len(v) == 1:
edge_array = _CAPI_DGLGraphInEdges_1(self._handle, int(v[0])) edge_array = _CAPI_DGLGraphInEdges_1(self, int(v[0]))
else: else:
v_array = v.todgltensor() v_array = v.todgltensor()
edge_array = _CAPI_DGLGraphInEdges_2(self._handle, v_array) edge_array = _CAPI_DGLGraphInEdges_2(self, v_array)
src = utils.toindex(edge_array(0)) src = utils.toindex(edge_array(0))
dst = utils.toindex(edge_array(1)) dst = utils.toindex(edge_array(1))
eid = utils.toindex(edge_array(2)) eid = utils.toindex(edge_array(2))
...@@ -397,10 +394,10 @@ class GraphIndex(object): ...@@ -397,10 +394,10 @@ class GraphIndex(object):
The edge ids. The edge ids.
""" """
if len(v) == 1: if len(v) == 1:
edge_array = _CAPI_DGLGraphOutEdges_1(self._handle, int(v[0])) edge_array = _CAPI_DGLGraphOutEdges_1(self, int(v[0]))
else: else:
v_array = v.todgltensor() v_array = v.todgltensor()
edge_array = _CAPI_DGLGraphOutEdges_2(self._handle, v_array) edge_array = _CAPI_DGLGraphOutEdges_2(self, v_array)
src = utils.toindex(edge_array(0)) src = utils.toindex(edge_array(0))
dst = utils.toindex(edge_array(1)) dst = utils.toindex(edge_array(1))
eid = utils.toindex(edge_array(2)) eid = utils.toindex(edge_array(2))
...@@ -430,7 +427,7 @@ class GraphIndex(object): ...@@ -430,7 +427,7 @@ class GraphIndex(object):
""" """
if order is None: if order is None:
order = "" order = ""
edge_array = _CAPI_DGLGraphEdges(self._handle, order) edge_array = _CAPI_DGLGraphEdges(self, order)
src = edge_array(0) src = edge_array(0)
dst = edge_array(1) dst = edge_array(1)
eid = edge_array(2) eid = edge_array(2)
...@@ -452,7 +449,7 @@ class GraphIndex(object): ...@@ -452,7 +449,7 @@ class GraphIndex(object):
int int
The in degree. The in degree.
""" """
return _CAPI_DGLGraphInDegree(self._handle, int(v)) return _CAPI_DGLGraphInDegree(self, int(v))
def in_degrees(self, v): def in_degrees(self, v):
"""Return the in degrees of the nodes. """Return the in degrees of the nodes.
...@@ -468,7 +465,7 @@ class GraphIndex(object): ...@@ -468,7 +465,7 @@ class GraphIndex(object):
The in degree array. The in degree array.
""" """
v_array = v.todgltensor() v_array = v.todgltensor()
return utils.toindex(_CAPI_DGLGraphInDegrees(self._handle, v_array)) return utils.toindex(_CAPI_DGLGraphInDegrees(self, v_array))
def out_degree(self, v): def out_degree(self, v):
"""Return the out degree of the node. """Return the out degree of the node.
...@@ -483,7 +480,7 @@ class GraphIndex(object): ...@@ -483,7 +480,7 @@ class GraphIndex(object):
int int
The out degree. The out degree.
""" """
return _CAPI_DGLGraphOutDegree(self._handle, int(v)) return _CAPI_DGLGraphOutDegree(self, int(v))
def out_degrees(self, v): def out_degrees(self, v):
"""Return the out degrees of the nodes. """Return the out degrees of the nodes.
...@@ -499,7 +496,7 @@ class GraphIndex(object): ...@@ -499,7 +496,7 @@ class GraphIndex(object):
The out degree array. The out degree array.
""" """
v_array = v.todgltensor() v_array = v.todgltensor()
return utils.toindex(_CAPI_DGLGraphOutDegrees(self._handle, v_array)) return utils.toindex(_CAPI_DGLGraphOutDegrees(self, v_array))
def node_subgraph(self, v): def node_subgraph(self, v):
"""Return the induced node subgraph. """Return the induced node subgraph.
...@@ -515,9 +512,9 @@ class GraphIndex(object): ...@@ -515,9 +512,9 @@ class GraphIndex(object):
The subgraph index. The subgraph index.
""" """
v_array = v.todgltensor() v_array = v.todgltensor()
rst = _CAPI_DGLGraphVertexSubgraph(self._handle, v_array) rst = _CAPI_DGLGraphVertexSubgraph(self, v_array)
induced_edges = utils.toindex(rst(2)) induced_edges = utils.toindex(rst(2))
gidx = GraphIndex(rst(0)) gidx = rst(0)
return SubgraphIndex(gidx, self, v, induced_edges) return SubgraphIndex(gidx, self, v, induced_edges)
def node_subgraphs(self, vs_arr): def node_subgraphs(self, vs_arr):
...@@ -556,9 +553,9 @@ class GraphIndex(object): ...@@ -556,9 +553,9 @@ class GraphIndex(object):
The subgraph index. The subgraph index.
""" """
e_array = e.todgltensor() e_array = e.todgltensor()
rst = _CAPI_DGLGraphEdgeSubgraph(self._handle, e_array, preserve_nodes) rst = _CAPI_DGLGraphEdgeSubgraph(self, e_array, preserve_nodes)
induced_nodes = utils.toindex(rst(1)) induced_nodes = utils.toindex(rst(1))
gidx = GraphIndex(rst(0)) gidx = rst(0)
return SubgraphIndex(gidx, self, induced_nodes, e) return SubgraphIndex(gidx, self, induced_nodes, e)
@utils.cached_member(cache='_cache', prefix='scipy_adj') @utils.cached_member(cache='_cache', prefix='scipy_adj')
...@@ -588,7 +585,7 @@ class GraphIndex(object): ...@@ -588,7 +585,7 @@ class GraphIndex(object):
if not isinstance(transpose, bool): if not isinstance(transpose, bool):
raise DGLError('Expect bool value for "transpose" arg,' raise DGLError('Expect bool value for "transpose" arg,'
' but got %s.' % (type(transpose))) ' but got %s.' % (type(transpose)))
rst = _CAPI_DGLGraphGetAdj(self._handle, transpose, fmt) rst = _CAPI_DGLGraphGetAdj(self, transpose, fmt)
if fmt == "csr": if fmt == "csr":
indptr = utils.toindex(rst(0)).tonumpy() indptr = utils.toindex(rst(0)).tonumpy()
indices = utils.toindex(rst(1)).tonumpy() indices = utils.toindex(rst(1)).tonumpy()
...@@ -631,9 +628,9 @@ class GraphIndex(object): ...@@ -631,9 +628,9 @@ class GraphIndex(object):
The first element of the tuple is the shuffle order for outward graph The first element of the tuple is the shuffle order for outward graph
The second element of the tuple is the shuffle order for inward graph The second element of the tuple is the shuffle order for inward graph
""" """
csr = _CAPI_DGLGraphGetAdj(self._handle, True, "csr") csr = _CAPI_DGLGraphGetAdj(self, True, "csr")
order = csr(2) order = csr(2)
rev_csr = _CAPI_DGLGraphGetAdj(self._handle, False, "csr") rev_csr = _CAPI_DGLGraphGetAdj(self, False, "csr")
rev_order = rev_csr(2) rev_order = rev_csr(2)
return utils.toindex(order), utils.toindex(rev_order) return utils.toindex(order), utils.toindex(rev_order)
...@@ -665,7 +662,7 @@ class GraphIndex(object): ...@@ -665,7 +662,7 @@ class GraphIndex(object):
raise DGLError('Expect bool value for "transpose" arg,' raise DGLError('Expect bool value for "transpose" arg,'
' but got %s.' % (type(transpose))) ' but got %s.' % (type(transpose)))
fmt = F.get_preferred_sparse_format() fmt = F.get_preferred_sparse_format()
rst = _CAPI_DGLGraphGetAdj(self._handle, transpose, fmt) rst = _CAPI_DGLGraphGetAdj(self, transpose, fmt)
if fmt == "csr": if fmt == "csr":
indptr = F.copy_to(utils.toindex(rst(0)).tousertensor(), ctx) indptr = F.copy_to(utils.toindex(rst(0)).tousertensor(), ctx)
indices = F.copy_to(utils.toindex(rst(1)).tousertensor(), ctx) indices = F.copy_to(utils.toindex(rst(1)).tousertensor(), ctx)
...@@ -794,8 +791,7 @@ class GraphIndex(object): ...@@ -794,8 +791,7 @@ class GraphIndex(object):
GraphIndex GraphIndex
The line graph of this graph. The line graph of this graph.
""" """
handle = _CAPI_DGLGraphLineGraph(self._handle, backtracking) return _CAPI_DGLGraphLineGraph(self, backtracking)
return GraphIndex(handle)
def to_immutable(self): def to_immutable(self):
"""Convert this graph index to an immutable one. """Convert this graph index to an immutable one.
...@@ -805,8 +801,7 @@ class GraphIndex(object): ...@@ -805,8 +801,7 @@ class GraphIndex(object):
GraphIndex GraphIndex
An immutable graph index. An immutable graph index.
""" """
handle = _CAPI_DGLToImmutable(self._handle) return _CAPI_DGLToImmutable(self)
return GraphIndex(handle)
def ctx(self): def ctx(self):
"""Return the context of this graph index. """Return the context of this graph index.
...@@ -816,7 +811,7 @@ class GraphIndex(object): ...@@ -816,7 +811,7 @@ class GraphIndex(object):
DGLContext DGLContext
The context of the graph. The context of the graph.
""" """
return _CAPI_DGLGraphContext(self._handle) return _CAPI_DGLGraphContext(self)
def copy_to(self, ctx): def copy_to(self, ctx):
"""Copy this immutable graph index to the given device context. """Copy this immutable graph index to the given device context.
...@@ -833,8 +828,7 @@ class GraphIndex(object): ...@@ -833,8 +828,7 @@ class GraphIndex(object):
GraphIndex GraphIndex
The graph index on the given device context. The graph index on the given device context.
""" """
handle = _CAPI_DGLImmutableGraphCopyTo(self._handle, ctx.device_type, ctx.device_id) return _CAPI_DGLImmutableGraphCopyTo(self, ctx.device_type, ctx.device_id)
return GraphIndex(handle)
def copyto_shared_mem(self, edge_dir, shared_mem_name): def copyto_shared_mem(self, edge_dir, shared_mem_name):
"""Copy this immutable graph index to shared memory. """Copy this immutable graph index to shared memory.
...@@ -853,8 +847,7 @@ class GraphIndex(object): ...@@ -853,8 +847,7 @@ class GraphIndex(object):
GraphIndex GraphIndex
The graph index on the given device context. The graph index on the given device context.
""" """
handle = _CAPI_DGLImmutableGraphCopyToSharedMem(self._handle, edge_dir, shared_mem_name) return _CAPI_DGLImmutableGraphCopyToSharedMem(self, edge_dir, shared_mem_name)
return GraphIndex(handle)
def nbits(self): def nbits(self):
"""Return the number of integer bits used in the storage (32 or 64). """Return the number of integer bits used in the storage (32 or 64).
...@@ -864,7 +857,7 @@ class GraphIndex(object): ...@@ -864,7 +857,7 @@ class GraphIndex(object):
int int
The number of bits. The number of bits.
""" """
return _CAPI_DGLGraphNumBits(self._handle) return _CAPI_DGLGraphNumBits(self)
def bits_needed(self): def bits_needed(self):
"""Return the number of integer bits needed to represent the graph """Return the number of integer bits needed to represent the graph
...@@ -894,8 +887,7 @@ class GraphIndex(object): ...@@ -894,8 +887,7 @@ class GraphIndex(object):
GraphIndex GraphIndex
The graph index stored using the given number of bits. The graph index stored using the given number of bits.
""" """
handle = _CAPI_DGLImmutableGraphAsNumBits(self._handle, int(bits)) return _CAPI_DGLImmutableGraphAsNumBits(self, int(bits))
return GraphIndex(handle)
class SubgraphIndex(object): class SubgraphIndex(object):
"""Internal subgraph data structure. """Internal subgraph data structure.
...@@ -955,19 +947,17 @@ def from_coo(num_nodes, src, dst, is_multigraph, readonly): ...@@ -955,19 +947,17 @@ def from_coo(num_nodes, src, dst, is_multigraph, readonly):
if is_multigraph is None: if is_multigraph is None:
is_multigraph = BoolFlag.BOOL_UNKNOWN is_multigraph = BoolFlag.BOOL_UNKNOWN
if readonly: if readonly:
handle = _CAPI_DGLGraphCreate( gidx = _CAPI_DGLGraphCreate(
src.todgltensor(), src.todgltensor(),
dst.todgltensor(), dst.todgltensor(),
int(is_multigraph), int(is_multigraph),
int(num_nodes), int(num_nodes),
readonly) readonly)
gidx = GraphIndex(handle)
else: else:
if is_multigraph is BoolFlag.BOOL_UNKNOWN: if is_multigraph is BoolFlag.BOOL_UNKNOWN:
# TODO(minjie): better behavior in the future # TODO(minjie): better behavior in the future
is_multigraph = BoolFlag.BOOL_FALSE is_multigraph = BoolFlag.BOOL_FALSE
handle = _CAPI_DGLGraphCreateMutable(bool(is_multigraph)) gidx = _CAPI_DGLGraphCreateMutable(bool(is_multigraph))
gidx = GraphIndex(handle)
gidx.add_nodes(num_nodes) gidx.add_nodes(num_nodes)
gidx.add_edges(src, dst) gidx.add_edges(src, dst)
return gidx return gidx
...@@ -993,13 +983,13 @@ def from_csr(indptr, indices, is_multigraph, ...@@ -993,13 +983,13 @@ def from_csr(indptr, indices, is_multigraph,
indices = utils.toindex(indices) indices = utils.toindex(indices)
if is_multigraph is None: if is_multigraph is None:
is_multigraph = BoolFlag.BOOL_UNKNOWN is_multigraph = BoolFlag.BOOL_UNKNOWN
handle = _CAPI_DGLGraphCSRCreate( gidx = _CAPI_DGLGraphCSRCreate(
indptr.todgltensor(), indptr.todgltensor(),
indices.todgltensor(), indices.todgltensor(),
shared_mem_name, shared_mem_name,
int(is_multigraph), int(is_multigraph),
direction) direction)
return GraphIndex(handle) return gidx
def from_shared_mem_csr_matrix(shared_mem_name, def from_shared_mem_csr_matrix(shared_mem_name,
num_nodes, num_edges, edge_dir, num_nodes, num_edges, edge_dir,
...@@ -1017,12 +1007,12 @@ def from_shared_mem_csr_matrix(shared_mem_name, ...@@ -1017,12 +1007,12 @@ def from_shared_mem_csr_matrix(shared_mem_name,
edge_dir : string edge_dir : string
the edge direction. The supported option is "in" and "out". the edge direction. The supported option is "in" and "out".
""" """
handle = _CAPI_DGLGraphCSRCreateMMap( gidx = _CAPI_DGLGraphCSRCreateMMap(
shared_mem_name, shared_mem_name,
int(num_nodes), int(num_edges), int(num_nodes), int(num_edges),
is_multigraph, is_multigraph,
edge_dir) edge_dir)
return GraphIndex(handle) return gidx
def from_networkx(nx_graph, readonly): def from_networkx(nx_graph, readonly):
"""Convert from networkx graph. """Convert from networkx graph.
...@@ -1175,10 +1165,7 @@ def disjoint_union(graphs): ...@@ -1175,10 +1165,7 @@ def disjoint_union(graphs):
GraphIndex GraphIndex
The disjoint union The disjoint union
""" """
inputs = c_array(GraphIndexHandle, [gr._handle for gr in graphs]) return _CAPI_DGLDisjointUnion(list(graphs))
inputs = ctypes.cast(inputs, ctypes.c_void_p)
handle = _CAPI_DGLDisjointUnion(inputs, len(graphs))
return GraphIndex(handle)
def disjoint_partition(graph, num_or_size_splits): def disjoint_partition(graph, num_or_size_splits):
"""Partition the graph disjointly. """Partition the graph disjointly.
...@@ -1202,17 +1189,13 @@ def disjoint_partition(graph, num_or_size_splits): ...@@ -1202,17 +1189,13 @@ def disjoint_partition(graph, num_or_size_splits):
""" """
if isinstance(num_or_size_splits, utils.Index): if isinstance(num_or_size_splits, utils.Index):
rst = _CAPI_DGLDisjointPartitionBySizes( rst = _CAPI_DGLDisjointPartitionBySizes(
graph._handle, graph,
num_or_size_splits.todgltensor()) num_or_size_splits.todgltensor())
else: else:
rst = _CAPI_DGLDisjointPartitionByNum( rst = _CAPI_DGLDisjointPartitionByNum(
graph._handle, graph,
int(num_or_size_splits)) int(num_or_size_splits))
graphs = [] return rst
for val in rst.asnumpy():
handle = ctypes.cast(int(val), ctypes.c_void_p)
graphs.append(GraphIndex(handle))
return graphs
def create_graph_index(graph_data, multigraph, readonly): def create_graph_index(graph_data, multigraph, readonly):
"""Create a graph index object. """Create a graph index object.
...@@ -1236,8 +1219,7 @@ def create_graph_index(graph_data, multigraph, readonly): ...@@ -1236,8 +1219,7 @@ def create_graph_index(graph_data, multigraph, readonly):
raise Exception("can't create an empty immutable graph") raise Exception("can't create an empty immutable graph")
if multigraph is None: if multigraph is None:
multigraph = False multigraph = False
handle = _CAPI_DGLGraphCreateMutable(multigraph) return _CAPI_DGLGraphCreateMutable(multigraph)
return GraphIndex(handle)
elif isinstance(graph_data, (list, tuple)): elif isinstance(graph_data, (list, tuple)):
# edge list # edge list
return from_edge_list(graph_data, multigraph, readonly) return from_edge_list(graph_data, multigraph, readonly)
......
...@@ -139,7 +139,7 @@ def binary_op_reduce(reducer, op, G, A_target, B_target, A, B, out, ...@@ -139,7 +139,7 @@ def binary_op_reduce(reducer, op, G, A_target, B_target, A, B, out,
if out_rows is None: if out_rows is None:
out_rows = empty([]) out_rows = empty([])
_CAPI_DGLKernelBinaryOpReduce( _CAPI_DGLKernelBinaryOpReduce(
reducer, op, G._handle, reducer, op, G,
int(A_target), int(B_target), int(A_target), int(B_target),
A, B, out, A, B, out,
A_rows, B_rows, out_rows) A_rows, B_rows, out_rows)
...@@ -203,7 +203,7 @@ def backward_lhs_binary_op_reduce( ...@@ -203,7 +203,7 @@ def backward_lhs_binary_op_reduce(
if out_rows is None: if out_rows is None:
out_rows = empty([]) out_rows = empty([])
_CAPI_DGLKernelBackwardLhsBinaryOpReduce( _CAPI_DGLKernelBackwardLhsBinaryOpReduce(
reducer, op, G._handle, reducer, op, G,
int(A_target), int(B_target), int(A_target), int(B_target),
A_rows, B_rows, out_rows, A_rows, B_rows, out_rows,
A, B, out, A, B, out,
...@@ -268,7 +268,7 @@ def backward_rhs_binary_op_reduce( ...@@ -268,7 +268,7 @@ def backward_rhs_binary_op_reduce(
if out_rows is None: if out_rows is None:
out_rows = empty([]) out_rows = empty([])
_CAPI_DGLKernelBackwardRhsBinaryOpReduce( _CAPI_DGLKernelBackwardRhsBinaryOpReduce(
reducer, op, G._handle, reducer, op, G,
int(A_target), int(B_target), int(A_target), int(B_target),
A_rows, B_rows, out_rows, A_rows, B_rows, out_rows,
A, B, out, A, B, out,
...@@ -365,7 +365,7 @@ def copy_reduce(reducer, G, target, ...@@ -365,7 +365,7 @@ def copy_reduce(reducer, G, target,
if out_rows is None: if out_rows is None:
out_rows = empty([]) out_rows = empty([])
_CAPI_DGLKernelCopyReduce( _CAPI_DGLKernelCopyReduce(
reducer, G._handle, int(target), reducer, G, int(target),
X, out, X_rows, out_rows) X, out, X_rows, out_rows)
# pylint: disable=invalid-name # pylint: disable=invalid-name
...@@ -407,7 +407,7 @@ def backward_copy_reduce(reducer, G, target, ...@@ -407,7 +407,7 @@ def backward_copy_reduce(reducer, G, target,
if out_rows is None: if out_rows is None:
out_rows = empty([]) out_rows = empty([])
_CAPI_DGLKernelBackwardCopyReduce( _CAPI_DGLKernelBackwardCopyReduce(
reducer, G._handle, int(target), reducer, G, int(target),
X, out, grad_out, grad_X, X, out, grad_out, grad_X,
X_rows, out_rows) X_rows, out_rows)
......
...@@ -3,7 +3,6 @@ from __future__ import absolute_import ...@@ -3,7 +3,6 @@ from __future__ import absolute_import
from ._ffi.function import _init_api from ._ffi.function import _init_api
from .nodeflow import NodeFlow from .nodeflow import NodeFlow
from .utils import unwrap_to_ptr_list
from . import utils from . import utils
_init_api("dgl.network") _init_api("dgl.network")
...@@ -64,14 +63,14 @@ def _send_nodeflow(sender, nodeflow, recv_id): ...@@ -64,14 +63,14 @@ def _send_nodeflow(sender, nodeflow, recv_id):
recv_id : int recv_id : int
Receiver ID Receiver ID
""" """
graph_handle = nodeflow._graph._handle gidx = nodeflow._graph
node_mapping = nodeflow._node_mapping.todgltensor() node_mapping = nodeflow._node_mapping.todgltensor()
edge_mapping = nodeflow._edge_mapping.todgltensor() edge_mapping = nodeflow._edge_mapping.todgltensor()
layers_offsets = utils.toindex(nodeflow._layer_offsets).todgltensor() layers_offsets = utils.toindex(nodeflow._layer_offsets).todgltensor()
flows_offsets = utils.toindex(nodeflow._block_offsets).todgltensor() flows_offsets = utils.toindex(nodeflow._block_offsets).todgltensor()
_CAPI_SenderSendSubgraph(sender, _CAPI_SenderSendSubgraph(sender,
int(recv_id), int(recv_id),
graph_handle, gidx,
node_mapping, node_mapping,
edge_mapping, edge_mapping,
layers_offsets, layers_offsets,
...@@ -137,5 +136,5 @@ def _recv_nodeflow(receiver, graph): ...@@ -137,5 +136,5 @@ def _recv_nodeflow(receiver, graph):
else: else:
raise RuntimeError('Got unexpected control code {}'.format(res)) raise RuntimeError('Got unexpected control code {}'.format(res))
else: else:
hdl = unwrap_to_ptr_list(res) # res is of type List<NodeFlowObject>
return NodeFlow(graph, hdl[0]) return NodeFlow(graph, res[0])
"""Class for NodeFlow data structure.""" """Class for NodeFlow data structure."""
from __future__ import absolute_import from __future__ import absolute_import
import ctypes from ._ffi.object import register_object, ObjectBase
from ._ffi.function import _init_api from ._ffi.function import _init_api
from .base import ALL, is_all, DGLError from .base import ALL, is_all, DGLError
from . import backend as F from . import backend as F
from .frame import Frame, FrameRef from .frame import Frame, FrameRef
from .graph import DGLBaseGraph from .graph import DGLBaseGraph
from .graph_index import GraphIndex, transform_ids from .graph_index import transform_ids
from .runtime import ir, scheduler, Runtime from .runtime import ir, scheduler, Runtime
from . import utils from . import utils
from .view import LayerView, BlockView from .view import LayerView, BlockView
__all__ = ['NodeFlow'] __all__ = ['NodeFlow']
NodeFlowHandle = ctypes.c_void_p @register_object('graph.NodeFlow')
class NodeFlowObject(ObjectBase):
"""NodeFlow object"""
@property
def graph(self):
"""The graph structure of this nodeflow.
Returns
-------
GraphIndex
"""
return _CAPI_NodeFlowGetGraph(self)
@property
def layer_offsets(self):
"""The offsets of each layer.
Returns
-------
NDArray
"""
return _CAPI_NodeFlowGetLayerOffsets(self)
@property
def block_offsets(self):
"""The offsets of each block.
Returns
-------
NDArray
"""
return _CAPI_NodeFlowGetBlockOffsets(self)
@property
def node_mapping(self):
"""Mapping array from nodeflow node id to parent graph
Returns
-------
NDArray
"""
return _CAPI_NodeFlowGetNodeMapping(self)
@property
def edge_mapping(self):
"""Mapping array from nodeflow edge id to parent graph
Returns
-------
NDArray
"""
return _CAPI_NodeFlowGetEdgeMapping(self)
class NodeFlow(DGLBaseGraph): class NodeFlow(DGLBaseGraph):
"""The NodeFlow class stores the sampling results of Neighbor """The NodeFlow class stores the sampling results of Neighbor
...@@ -36,25 +87,16 @@ class NodeFlow(DGLBaseGraph): ...@@ -36,25 +87,16 @@ class NodeFlow(DGLBaseGraph):
---------- ----------
parent : DGLGraph parent : DGLGraph
The parent graph. The parent graph.
handle : NodeFlowHandle nfobj : NodeFlowObject
The handle to the underlying C structure. The nodeflow object
""" """
def __init__(self, parent, handle): def __init__(self, parent, nfobj):
# NOTE(minjie): handle is a pointer to the underlying C++ structure super(NodeFlow, self).__init__(nfobj.graph)
# defined in include/dgl/sampler.h. The constructor will save
# all its members in the python side and destroy the handler
# afterwards. One can view the given handle object as a transient
# argument pack to construct this python class.
# TODO(minjie): We should use TVM's Node system as a cleaner solution later.
super(NodeFlow, self).__init__(GraphIndex(_CAPI_NodeFlowGetGraph(handle)))
self._parent = parent self._parent = parent
self._node_mapping = utils.toindex(_CAPI_NodeFlowGetNodeMapping(handle)) self._node_mapping = utils.toindex(nfobj.node_mapping)
self._edge_mapping = utils.toindex(_CAPI_NodeFlowGetEdgeMapping(handle)) self._edge_mapping = utils.toindex(nfobj.edge_mapping)
self._layer_offsets = utils.toindex( self._layer_offsets = utils.toindex(nfobj.layer_offsets).tonumpy()
_CAPI_NodeFlowGetLayerOffsets(handle)).tonumpy() self._block_offsets = utils.toindex(nfobj.block_offsets).tonumpy()
self._block_offsets = utils.toindex(
_CAPI_NodeFlowGetBlockOffsets(handle)).tonumpy()
_CAPI_NodeFlowFree(handle)
# node/edge frames # node/edge frames
self._node_frames = [FrameRef(Frame(num_rows=self.layer_size(i))) \ self._node_frames = [FrameRef(Frame(num_rows=self.layer_size(i))) \
for i in range(self.num_layers)] for i in range(self.num_layers)]
...@@ -293,6 +335,7 @@ class NodeFlow(DGLBaseGraph): ...@@ -293,6 +335,7 @@ class NodeFlow(DGLBaseGraph):
The parent node id array. The parent node id array.
""" """
nid = utils.toindex(nid) nid = utils.toindex(nid)
# TODO(minjie): should not directly use []
return self._node_mapping.tousertensor()[nid.tousertensor()] return self._node_mapping.tousertensor()[nid.tousertensor()]
def map_to_parent_eid(self, eid): def map_to_parent_eid(self, eid):
...@@ -309,6 +352,7 @@ class NodeFlow(DGLBaseGraph): ...@@ -309,6 +352,7 @@ class NodeFlow(DGLBaseGraph):
The parent edge id array. The parent edge id array.
""" """
eid = utils.toindex(eid) eid = utils.toindex(eid)
# TODO(minjie): should not directly use []
return self._edge_mapping.tousertensor()[eid.tousertensor()] return self._edge_mapping.tousertensor()[eid.tousertensor()]
def map_from_parent_nid(self, layer_id, parent_nids, remap_local=False): def map_from_parent_nid(self, layer_id, parent_nids, remap_local=False):
...@@ -418,6 +462,7 @@ class NodeFlow(DGLBaseGraph): ...@@ -418,6 +462,7 @@ class NodeFlow(DGLBaseGraph):
assert layer_id + 1 < len(self._layer_offsets) assert layer_id + 1 < len(self._layer_offsets)
start = self._layer_offsets[layer_id] start = self._layer_offsets[layer_id]
end = self._layer_offsets[layer_id + 1] end = self._layer_offsets[layer_id + 1]
# TODO(minjie): should not directly use []
return self._node_mapping.tousertensor()[start:end] return self._node_mapping.tousertensor()[start:end]
def block_eid(self, block_id): def block_eid(self, block_id):
...@@ -456,6 +501,7 @@ class NodeFlow(DGLBaseGraph): ...@@ -456,6 +501,7 @@ class NodeFlow(DGLBaseGraph):
block_id = self._get_block_id(block_id) block_id = self._get_block_id(block_id)
start = self._block_offsets[block_id] start = self._block_offsets[block_id]
end = self._block_offsets[block_id + 1] end = self._block_offsets[block_id + 1]
# TODO(minjie): should not directly use []
ret = self._edge_mapping.tousertensor()[start:end] ret = self._edge_mapping.tousertensor()[start:end]
# If `add_self_loop` is enabled, the returned parent eid can be -1. # If `add_self_loop` is enabled, the returned parent eid can be -1.
# We have to make sure this case doesn't happen. # We have to make sure this case doesn't happen.
...@@ -487,7 +533,7 @@ class NodeFlow(DGLBaseGraph): ...@@ -487,7 +533,7 @@ class NodeFlow(DGLBaseGraph):
""" """
block_id = self._get_block_id(block_id) block_id = self._get_block_id(block_id)
layer0_size = self._layer_offsets[block_id + 1] - self._layer_offsets[block_id] layer0_size = self._layer_offsets[block_id + 1] - self._layer_offsets[block_id]
rst = _CAPI_NodeFlowGetBlockAdj(self._graph._handle, "coo", rst = _CAPI_NodeFlowGetBlockAdj(self._graph, "coo",
int(layer0_size), int(layer0_size),
int(self._layer_offsets[block_id + 1]), int(self._layer_offsets[block_id + 1]),
int(self._layer_offsets[block_id + 2]), int(self._layer_offsets[block_id + 2]),
...@@ -523,7 +569,7 @@ class NodeFlow(DGLBaseGraph): ...@@ -523,7 +569,7 @@ class NodeFlow(DGLBaseGraph):
fmt = F.get_preferred_sparse_format() fmt = F.get_preferred_sparse_format()
# We need to extract two layers. # We need to extract two layers.
layer0_size = self._layer_offsets[block_id + 1] - self._layer_offsets[block_id] layer0_size = self._layer_offsets[block_id + 1] - self._layer_offsets[block_id]
rst = _CAPI_NodeFlowGetBlockAdj(self._graph._handle, fmt, rst = _CAPI_NodeFlowGetBlockAdj(self._graph, fmt,
int(layer0_size), int(layer0_size),
int(self._layer_offsets[block_id + 1]), int(self._layer_offsets[block_id + 1]),
int(self._layer_offsets[block_id + 2]), int(self._layer_offsets[block_id + 2]),
......
"""Module for graph transformation methods.""" """Module for graph transformation methods."""
from ._ffi.function import _init_api from ._ffi.function import _init_api
from .graph import DGLGraph from .graph import DGLGraph
from .graph_index import GraphIndex
from .batched_graph import BatchedDGLGraph from .batched_graph import BatchedDGLGraph
__all__ = ['line_graph', 'reverse', 'to_simple_graph', 'to_bidirected'] __all__ = ['line_graph', 'reverse', 'to_simple_graph', 'to_bidirected']
...@@ -121,8 +120,8 @@ def to_simple_graph(g): ...@@ -121,8 +120,8 @@ def to_simple_graph(g):
DGLGraph DGLGraph
A simple graph. A simple graph.
""" """
newgidx = GraphIndex(_CAPI_DGLToSimpleGraph(g._graph.handle)) gidx = _CAPI_DGLToSimpleGraph(g._graph)
return DGLGraph(newgidx, readonly=True) return DGLGraph(gidx, readonly=True)
def to_bidirected(g, readonly=True): def to_bidirected(g, readonly=True):
"""Convert the graph to a bidirected graph. """Convert the graph to a bidirected graph.
...@@ -165,9 +164,9 @@ def to_bidirected(g, readonly=True): ...@@ -165,9 +164,9 @@ def to_bidirected(g, readonly=True):
(tensor([0, 1, 1, 0, 0]), tensor([0, 0, 0, 1, 1])) (tensor([0, 1, 1, 0, 0]), tensor([0, 0, 0, 1, 1]))
""" """
if readonly: if readonly:
newgidx = GraphIndex(_CAPI_DGLToBidirectedImmutableGraph(g._graph.handle)) newgidx = _CAPI_DGLToBidirectedImmutableGraph(g._graph)
else: else:
newgidx = GraphIndex(_CAPI_DGLToBidirectedMutableGraph(g._graph.handle)) newgidx = _CAPI_DGLToBidirectedMutableGraph(g._graph)
return DGLGraph(newgidx) return DGLGraph(newgidx)
_init_api("dgl.transform") _init_api("dgl.transform")
...@@ -39,9 +39,9 @@ def bfs_nodes_generator(graph, source, reverse=False): ...@@ -39,9 +39,9 @@ def bfs_nodes_generator(graph, source, reverse=False):
>>> list(dgl.bfs_nodes_generator(g, 0)) >>> list(dgl.bfs_nodes_generator(g, 0))
[tensor([0]), tensor([1]), tensor([2, 3]), tensor([4, 5])] [tensor([0]), tensor([1]), tensor([2, 3]), tensor([4, 5])]
""" """
ghandle = graph._graph._handle gidx = graph._graph
source = utils.toindex(source) source = utils.toindex(source)
ret = _CAPI_DGLBFSNodes(ghandle, source.todgltensor(), reverse) ret = _CAPI_DGLBFSNodes(gidx, source.todgltensor(), reverse)
all_nodes = utils.toindex(ret(0)).tousertensor() all_nodes = utils.toindex(ret(0)).tousertensor()
# TODO(minjie): how to support directly creating python list # TODO(minjie): how to support directly creating python list
sections = utils.toindex(ret(1)).tonumpy().tolist() sections = utils.toindex(ret(1)).tonumpy().tolist()
...@@ -79,9 +79,9 @@ def bfs_edges_generator(graph, source, reverse=False): ...@@ -79,9 +79,9 @@ def bfs_edges_generator(graph, source, reverse=False):
>>> list(dgl.bfs_edges_generator(g, 0)) >>> list(dgl.bfs_edges_generator(g, 0))
[tensor([0]), tensor([1, 2]), tensor([4, 5])] [tensor([0]), tensor([1, 2]), tensor([4, 5])]
""" """
ghandle = graph._graph._handle gidx = graph._graph
source = utils.toindex(source) source = utils.toindex(source)
ret = _CAPI_DGLBFSEdges(ghandle, source.todgltensor(), reverse) ret = _CAPI_DGLBFSEdges(gidx, source.todgltensor(), reverse)
all_edges = utils.toindex(ret(0)).tousertensor() all_edges = utils.toindex(ret(0)).tousertensor()
# TODO(minjie): how to support directly creating python list # TODO(minjie): how to support directly creating python list
sections = utils.toindex(ret(1)).tonumpy().tolist() sections = utils.toindex(ret(1)).tonumpy().tolist()
...@@ -116,8 +116,8 @@ def topological_nodes_generator(graph, reverse=False): ...@@ -116,8 +116,8 @@ def topological_nodes_generator(graph, reverse=False):
>>> list(dgl.topological_nodes_generator(g)) >>> list(dgl.topological_nodes_generator(g))
[tensor([0]), tensor([1]), tensor([2]), tensor([3, 4]), tensor([5])] [tensor([0]), tensor([1]), tensor([2]), tensor([3, 4]), tensor([5])]
""" """
ghandle = graph._graph._handle gidx = graph._graph
ret = _CAPI_DGLTopologicalNodes(ghandle, reverse) ret = _CAPI_DGLTopologicalNodes(gidx, reverse)
all_nodes = utils.toindex(ret(0)).tousertensor() all_nodes = utils.toindex(ret(0)).tousertensor()
# TODO(minjie): how to support directly creating python list # TODO(minjie): how to support directly creating python list
sections = utils.toindex(ret(1)).tonumpy().tolist() sections = utils.toindex(ret(1)).tonumpy().tolist()
...@@ -160,9 +160,9 @@ def dfs_edges_generator(graph, source, reverse=False): ...@@ -160,9 +160,9 @@ def dfs_edges_generator(graph, source, reverse=False):
>>> list(dgl.dfs_edges_generator(g, 0)) >>> list(dgl.dfs_edges_generator(g, 0))
[tensor([0]), tensor([1]), tensor([3]), tensor([5]), tensor([4])] [tensor([0]), tensor([1]), tensor([3]), tensor([5]), tensor([4])]
""" """
ghandle = graph._graph._handle gidx = graph._graph
source = utils.toindex(source) source = utils.toindex(source)
ret = _CAPI_DGLDFSEdges(ghandle, source.todgltensor(), reverse) ret = _CAPI_DGLDFSEdges(gidx, source.todgltensor(), reverse)
all_edges = utils.toindex(ret(0)).tousertensor() all_edges = utils.toindex(ret(0)).tousertensor()
# TODO(minjie): how to support directly creating python list # TODO(minjie): how to support directly creating python list
sections = utils.toindex(ret(1)).tonumpy().tolist() sections = utils.toindex(ret(1)).tonumpy().tolist()
...@@ -231,10 +231,10 @@ def dfs_labeled_edges_generator( ...@@ -231,10 +231,10 @@ def dfs_labeled_edges_generator(
(tensor([0]), tensor([1]), tensor([3]), tensor([5]), tensor([4]), tensor([2])), (tensor([0]), tensor([1]), tensor([3]), tensor([5]), tensor([4]), tensor([2])),
(tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([2])) (tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([0]), tensor([2]))
""" """
ghandle = graph._graph._handle gidx = graph._graph
source = utils.toindex(source) source = utils.toindex(source)
ret = _CAPI_DGLDFSLabeledEdges( ret = _CAPI_DGLDFSLabeledEdges(
ghandle, gidx,
source.todgltensor(), source.todgltensor(),
reverse, reverse,
has_reverse_edge, has_reverse_edge,
......
"""Utility module.""" """Utility module."""
from __future__ import absolute_import, division from __future__ import absolute_import, division
import ctypes
from collections.abc import Mapping, Iterable from collections.abc import Mapping, Iterable
from functools import wraps from functools import wraps
import numpy as np import numpy as np
from . import _api_internal
from .base import DGLError from .base import DGLError
from . import backend as F from . import backend as F
from . import ndarray as nd from . import ndarray as nd
...@@ -534,30 +532,6 @@ def get_edata_name(g, name): ...@@ -534,30 +532,6 @@ def get_edata_name(g, name):
name += '_' name += '_'
return name return name
def unwrap_to_ptr_list(wrapper):
"""Convert the internal vector wrapper to a python list of ctypes.c_void_p.
The wrapper will be destroyed after this function.
Parameters
----------
wrapper : ctypes.c_void_p
The handler to the wrapper.
Returns
-------
list of ctypes.c_void_p
A python list of void pointers.
"""
size = _api_internal._GetVectorWrapperSize(wrapper)
if size == 0:
return []
data = _api_internal._GetVectorWrapperData(wrapper)
data = ctypes.cast(data, ctypes.POINTER(ctypes.c_void_p * size))
rst = [ctypes.c_void_p(x) for x in data.contents]
_api_internal._FreeVectorWrapper(wrapper)
return rst
def to_dgl_context(ctx): def to_dgl_context(ctx):
"""Convert a backend context to DGLContext""" """Convert a backend context to DGLContext"""
device_type = nd.DGLContext.STR2MASK[F.device_type(ctx)] device_type = nd.DGLContext.STR2MASK[F.device_type(ctx)]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment