Unverified Commit 8ac27dad authored by Hongzhi (Steve), Chen's avatar Hongzhi (Steve), Chen Committed by GitHub
Browse files

[Misc] clang-format auto fix. (#4824)



* [Misc] clang-format auto fix.

* blabla

* ablabla

* blabla
Co-authored-by: default avatarSteve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent bcd37684
...@@ -6,13 +6,13 @@ ...@@ -6,13 +6,13 @@
#ifndef DGL_ARRAY_KERNEL_DECL_H_ #ifndef DGL_ARRAY_KERNEL_DECL_H_
#define DGL_ARRAY_KERNEL_DECL_H_ #define DGL_ARRAY_KERNEL_DECL_H_
#include <dgl/bcast.h>
#include <dgl/base_heterograph.h> #include <dgl/base_heterograph.h>
#include <dgl/bcast.h>
#include <dgl/runtime/ndarray.h> #include <dgl/runtime/ndarray.h>
#include <string> #include <string>
#include <vector>
#include <utility> #include <utility>
#include <vector>
namespace dgl { namespace dgl {
namespace aten { namespace aten {
...@@ -21,170 +21,129 @@ namespace aten { ...@@ -21,170 +21,129 @@ namespace aten {
* @brief Generalized Sparse Matrix Dense Matrix Multiplication on Csr format. * @brief Generalized Sparse Matrix Dense Matrix Multiplication on Csr format.
*/ */
template <int XPU, typename IdType, typename DType> template <int XPU, typename IdType, typename DType>
void SpMMCsr(const std::string& op, const std::string& reduce, void SpMMCsr(
const BcastOff& bcast, const std::string& op, const std::string& reduce, const BcastOff& bcast,
const aten::CSRMatrix& csr, const aten::CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
NDArray ufeat, std::vector<NDArray> out_aux);
NDArray efeat,
NDArray out,
std::vector<NDArray> out_aux);
/** /**
* @brief Generalized Sparse Matrix Dense Matrix Multiplication on Csr format * @brief Generalized Sparse Matrix Dense Matrix Multiplication on Csr format
with heterograph support. * with heterograph support.
*/ */
template <int XPU, typename IdType, typename DType> template <int XPU, typename IdType, typename DType>
void SpMMCsrHetero(const std::string& op, const std::string& reduce, void SpMMCsrHetero(
const BcastOff& bcast, const std::string& op, const std::string& reduce, const BcastOff& bcast,
const std::vector<CSRMatrix>& csr, const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat, std::vector<NDArray>* out,
const std::vector<NDArray>& efeat, std::vector<std::vector<NDArray>>* out_aux,
std::vector<NDArray>* out, const std::vector<dgl_type_t>& ufeat_eid,
std::vector<std::vector<NDArray>>* out_aux, const std::vector<dgl_type_t>& out_eid);
const std::vector<dgl_type_t>& ufeat_eid,
const std::vector<dgl_type_t>& out_eid);
/** /**
* @brief Generalized Sparse Matrix Dense Matrix Multiplication on Coo format. * @brief Generalized Sparse Matrix Dense Matrix Multiplication on Coo format.
*/ */
template <int XPU, typename IdType, typename DType> template <int XPU, typename IdType, typename DType>
void SpMMCoo(const std::string& op, const std::string& reduce, void SpMMCoo(
const BcastOff& bcast, const std::string& op, const std::string& reduce, const BcastOff& bcast,
const aten::COOMatrix& coo, const aten::COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
NDArray ufeat, std::vector<NDArray> out_aux);
NDArray efeat,
NDArray out,
std::vector<NDArray> out_aux);
/** /**
* @brief Generalized Sampled Dense-Dense Matrix Multiplication on Csr format. * @brief Generalized Sampled Dense-Dense Matrix Multiplication on Csr format.
*/ */
template <int XPU, typename IdType, typename DType> template <int XPU, typename IdType, typename DType>
void SDDMMCsr(const std::string& op, void SDDMMCsr(
const BcastOff& bcast, const std::string& op, const BcastOff& bcast, const aten::CSRMatrix& csr,
const aten::CSRMatrix& csr, NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
NDArray lhs,
NDArray rhs,
NDArray out,
int lhs_target,
int rhs_target);
/** /**
* @brief Generalized Sampled Dense-Dense Matrix Multiplication on Csr * @brief Generalized Sampled Dense-Dense Matrix Multiplication on Csr format
format with heterograph support. * with heterograph support.
*/ */
template <int XPU, typename IdType, typename DType> template <int XPU, typename IdType, typename DType>
void SDDMMCsrHetero(const std::string& op, void SDDMMCsrHetero(
const BcastOff& bcast, const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr, const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& vec_lhs,
const std::vector<NDArray>& vec_lhs, const std::vector<NDArray>& vec_rhs, std::vector<NDArray> vec_out,
const std::vector<NDArray>& vec_rhs, int lhs_target, int rhs_target, const std::vector<dgl_type_t>& ufeat_eid,
std::vector<NDArray> vec_out, const std::vector<dgl_type_t>& out_eid);
int lhs_target,
int rhs_target,
const std::vector<dgl_type_t>& ufeat_eid,
const std::vector<dgl_type_t>& out_eid);
/** /**
* @brief Generalized Sampled Dense-Dense Matrix Multiplication on Coo format. * @brief Generalized Sampled Dense-Dense Matrix Multiplication on Coo format.
*/ */
template <int XPU, typename IdType, typename DType> template <int XPU, typename IdType, typename DType>
void SDDMMCoo(const std::string& op, void SDDMMCoo(
const BcastOff& bcast, const std::string& op, const BcastOff& bcast, const aten::COOMatrix& coo,
const aten::COOMatrix& coo, NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
NDArray lhs,
NDArray rhs,
NDArray out,
int lhs_target,
int rhs_target);
/** /**
* @brief Generalized Sampled Dense-Dense Matrix Multiplication on Coo * @brief Generalized Sampled Dense-Dense Matrix Multiplication on Coo format
format with heterograph support. * with heterograph support.
*/ */
template <int XPU, typename IdType, typename DType> template <int XPU, typename IdType, typename DType>
void SDDMMCooHetero(const std::string& op, void SDDMMCooHetero(
const BcastOff& bcast, const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo, const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& vec_lhs,
const std::vector<NDArray>& vec_lhs, const std::vector<NDArray>& vec_rhs, std::vector<NDArray> vec_out,
const std::vector<NDArray>& vec_rhs, int lhs_target, int rhs_target, const std::vector<dgl_type_t>& lhs_eid,
std::vector<NDArray> vec_out, const std::vector<dgl_type_t>& rhs_eid);
int lhs_target,
int rhs_target,
const std::vector<dgl_type_t>& lhs_eid,
const std::vector<dgl_type_t>& rhs_eid);
/** /**
* @brief Generalized Dense Matrix-Matrix Multiplication according to relation types. * @brief Generalized Dense Matrix-Matrix Multiplication according to relation
* types.
*/ */
template <int XPU, typename IdType, typename DType> template <int XPU, typename IdType, typename DType>
void GatherMM(const NDArray A, void GatherMM(
const NDArray B, const NDArray A, const NDArray B, NDArray out, const NDArray idx_a,
NDArray out, const NDArray idx_b);
const NDArray idx_a,
const NDArray idx_b);
/** /**
* @brief Generalized Dense Matrix-Matrix Multiplication according to relation types. * @brief Generalized Dense Matrix-Matrix Multiplication according to relation
* types.
*/ */
template <int XPU, typename IdType, typename DType> template <int XPU, typename IdType, typename DType>
void GatherMMScatter(const NDArray A, void GatherMMScatter(
const NDArray B, const NDArray A, const NDArray B, NDArray out, const NDArray idx_a,
NDArray out, const NDArray idx_b, const NDArray idx_c);
const NDArray idx_a,
const NDArray idx_b,
const NDArray idx_c);
/** /**
* @brief Generalized segmented dense Matrix-Matrix Multiplication. * @brief Generalized segmented dense Matrix-Matrix Multiplication.
*/ */
template <int XPU, typename IdType, typename DType> template <int XPU, typename IdType, typename DType>
void SegmentMM(const NDArray A, void SegmentMM(
const NDArray B, const NDArray A, const NDArray B, NDArray out, const NDArray seglen_A,
NDArray out, bool a_trans, bool b_trans);
const NDArray seglen_A,
bool a_trans, bool b_trans);
template <int XPU, typename IdType, typename DType> template <int XPU, typename IdType, typename DType>
void SegmentMMBackwardB(const NDArray A, void SegmentMMBackwardB(
const NDArray dC, const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
NDArray dB,
const NDArray seglen);
/** /**
* @brief Segment reduce. * @brief Segment reduce.
*/ */
template <int XPU, typename IdType, typename DType> template <int XPU, typename IdType, typename DType>
void SegmentReduce(const std::string& op, void SegmentReduce(
NDArray feat, const std::string& op, NDArray feat, NDArray offsets, NDArray out,
NDArray offsets, NDArray arg);
NDArray out,
NDArray arg);
/** /**
* @brief Scatter Add on first dimension. * @brief Scatter Add on first dimension.
*/ */
template <int XPU, typename IdType, typename DType> template <int XPU, typename IdType, typename DType>
void ScatterAdd(NDArray feat, void ScatterAdd(NDArray feat, NDArray idx, NDArray out);
NDArray idx,
NDArray out);
/** /**
* @brief Update gradients for reduce operator max and min on first dimension. * @brief Update gradients for reduce operator max and min on first dimension.
*/ */
template <int XPU, typename IdType, typename DType> template <int XPU, typename IdType, typename DType>
void UpdateGradMinMax_hetero(const HeteroGraphPtr& g, void UpdateGradMinMax_hetero(
const std::string& op, const HeteroGraphPtr& g, const std::string& op,
const std::vector<NDArray>& feat, const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx, const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
const std::vector<NDArray>& idx_etype,
std::vector<NDArray>* out);
/** /**
* @brief Backward function of segment cmp. * @brief Backward function of segment cmp.
*/ */
template <int XPU, typename IdType, typename DType> template <int XPU, typename IdType, typename DType>
void BackwardSegmentCmp(NDArray feat, void BackwardSegmentCmp(NDArray feat, NDArray arg, NDArray out);
NDArray arg,
NDArray out);
/** /**
* @brief Sparse-sparse matrix multiplication * @brief Sparse-sparse matrix multiplication
...@@ -200,9 +159,7 @@ void BackwardSegmentCmp(NDArray feat, ...@@ -200,9 +159,7 @@ void BackwardSegmentCmp(NDArray feat,
*/ */
template <int XPU, typename IdType, typename DType> template <int XPU, typename IdType, typename DType>
std::pair<CSRMatrix, NDArray> CSRMM( std::pair<CSRMatrix, NDArray> CSRMM(
const CSRMatrix& A, const CSRMatrix& A, NDArray A_weights, const CSRMatrix& B,
NDArray A_weights,
const CSRMatrix& B,
NDArray B_weights); NDArray B_weights);
/** /**
...@@ -217,29 +174,22 @@ std::pair<CSRMatrix, NDArray> CSRMM( ...@@ -217,29 +174,22 @@ std::pair<CSRMatrix, NDArray> CSRMM(
*/ */
template <int XPU, typename IdType, typename DType> template <int XPU, typename IdType, typename DType>
std::pair<CSRMatrix, NDArray> CSRSum( std::pair<CSRMatrix, NDArray> CSRSum(
const std::vector<CSRMatrix>& A, const std::vector<CSRMatrix>& A, const std::vector<NDArray>& A_weights);
const std::vector<NDArray>& A_weights);
/** /**
* @brief Edge_softmax_csr forward function on Csr format. * @brief Edge_softmax_csr forward function on Csr format.
*/ */
template <int XPU, typename IdType, typename DType> template <int XPU, typename IdType, typename DType>
void Edge_softmax_csr_forward(const std::string& op, void Edge_softmax_csr_forward(
const BcastOff& bcast, const std::string& op, const BcastOff& bcast, const aten::CSRMatrix& csr,
const aten::CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out);
NDArray ufeat,
NDArray efeat,
NDArray out);
/** /**
* @brief Edge_softmax_csr backward function on Csr format. * @brief Edge_softmax_csr backward function on Csr format.
*/ */
template <int XPU, typename IdType, typename DType> template <int XPU, typename IdType, typename DType>
void Edge_softmax_csr_backward(const std::string& op, void Edge_softmax_csr_backward(
const BcastOff& bcast, const std::string& op, const BcastOff& bcast, const aten::CSRMatrix& csr,
const aten::CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out);
NDArray ufeat,
NDArray efeat,
NDArray out);
} // namespace aten } // namespace aten
} // namespace dgl } // namespace dgl
......
...@@ -7,15 +7,17 @@ ...@@ -7,15 +7,17 @@
#ifndef DGL_GRAPH_HETEROGRAPH_H_ #ifndef DGL_GRAPH_HETEROGRAPH_H_
#define DGL_GRAPH_HETEROGRAPH_H_ #define DGL_GRAPH_HETEROGRAPH_H_
#include <dgl/runtime/shared_mem.h>
#include <dgl/base_heterograph.h> #include <dgl/base_heterograph.h>
#include <dgl/lazy.h> #include <dgl/lazy.h>
#include <utility> #include <dgl/runtime/shared_mem.h>
#include <string>
#include <vector> #include <memory>
#include <set> #include <set>
#include <string>
#include <tuple> #include <tuple>
#include <memory> #include <utility>
#include <vector>
#include "./unit_graph.h" #include "./unit_graph.h"
#include "shared_mem_manager.h" #include "shared_mem_manager.h"
...@@ -25,8 +27,7 @@ namespace dgl { ...@@ -25,8 +27,7 @@ namespace dgl {
class HeteroGraph : public BaseHeteroGraph { class HeteroGraph : public BaseHeteroGraph {
public: public:
HeteroGraph( HeteroGraph(
GraphPtr meta_graph, GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& rel_graphs,
const std::vector<HeteroGraphPtr>& rel_graphs,
const std::vector<int64_t>& num_nodes_per_type = {}); const std::vector<int64_t>& num_nodes_per_type = {});
HeteroGraphPtr GetRelationGraph(dgl_type_t etype) const override { HeteroGraphPtr GetRelationGraph(dgl_type_t etype) const override {
...@@ -46,31 +47,21 @@ class HeteroGraph : public BaseHeteroGraph { ...@@ -46,31 +47,21 @@ class HeteroGraph : public BaseHeteroGraph {
LOG(FATAL) << "Bipartite graph is not mutable."; LOG(FATAL) << "Bipartite graph is not mutable.";
} }
void Clear() override { void Clear() override { LOG(FATAL) << "Bipartite graph is not mutable."; }
LOG(FATAL) << "Bipartite graph is not mutable.";
}
DGLDataType DataType() const override { DGLDataType DataType() const override {
return relation_graphs_[0]->DataType(); return relation_graphs_[0]->DataType();
} }
DGLContext Context() const override { DGLContext Context() const override { return relation_graphs_[0]->Context(); }
return relation_graphs_[0]->Context();
}
bool IsPinned() const override { bool IsPinned() const override { return relation_graphs_[0]->IsPinned(); }
return relation_graphs_[0]->IsPinned();
}
uint8_t NumBits() const override { uint8_t NumBits() const override { return relation_graphs_[0]->NumBits(); }
return relation_graphs_[0]->NumBits();
}
bool IsMultigraph() const override; bool IsMultigraph() const override;
bool IsReadonly() const override { bool IsReadonly() const override { return true; }
return true;
}
uint64_t NumVertices(dgl_type_t vtype) const override { uint64_t NumVertices(dgl_type_t vtype) const override {
CHECK(meta_graph_->HasVertex(vtype)) << "Invalid vertex type: " << vtype; CHECK(meta_graph_->HasVertex(vtype)) << "Invalid vertex type: " << vtype;
...@@ -91,11 +82,13 @@ class HeteroGraph : public BaseHeteroGraph { ...@@ -91,11 +82,13 @@ class HeteroGraph : public BaseHeteroGraph {
BoolArray HasVertices(dgl_type_t vtype, IdArray vids) const override; BoolArray HasVertices(dgl_type_t vtype, IdArray vids) const override;
bool HasEdgeBetween(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const override { bool HasEdgeBetween(
dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const override {
return GetRelationGraph(etype)->HasEdgeBetween(0, src, dst); return GetRelationGraph(etype)->HasEdgeBetween(0, src, dst);
} }
BoolArray HasEdgesBetween(dgl_type_t etype, IdArray src_ids, IdArray dst_ids) const override { BoolArray HasEdgesBetween(
dgl_type_t etype, IdArray src_ids, IdArray dst_ids) const override {
return GetRelationGraph(etype)->HasEdgesBetween(0, src_ids, dst_ids); return GetRelationGraph(etype)->HasEdgesBetween(0, src_ids, dst_ids);
} }
...@@ -111,15 +104,18 @@ class HeteroGraph : public BaseHeteroGraph { ...@@ -111,15 +104,18 @@ class HeteroGraph : public BaseHeteroGraph {
return GetRelationGraph(etype)->EdgeId(0, src, dst); return GetRelationGraph(etype)->EdgeId(0, src, dst);
} }
EdgeArray EdgeIdsAll(dgl_type_t etype, IdArray src, IdArray dst) const override { EdgeArray EdgeIdsAll(
dgl_type_t etype, IdArray src, IdArray dst) const override {
return GetRelationGraph(etype)->EdgeIdsAll(0, src, dst); return GetRelationGraph(etype)->EdgeIdsAll(0, src, dst);
} }
IdArray EdgeIdsOne(dgl_type_t etype, IdArray src, IdArray dst) const override { IdArray EdgeIdsOne(
dgl_type_t etype, IdArray src, IdArray dst) const override {
return GetRelationGraph(etype)->EdgeIdsOne(0, src, dst); return GetRelationGraph(etype)->EdgeIdsOne(0, src, dst);
} }
std::pair<dgl_id_t, dgl_id_t> FindEdge(dgl_type_t etype, dgl_id_t eid) const override { std::pair<dgl_id_t, dgl_id_t> FindEdge(
dgl_type_t etype, dgl_id_t eid) const override {
return GetRelationGraph(etype)->FindEdge(0, eid); return GetRelationGraph(etype)->FindEdge(0, eid);
} }
...@@ -143,7 +139,8 @@ class HeteroGraph : public BaseHeteroGraph { ...@@ -143,7 +139,8 @@ class HeteroGraph : public BaseHeteroGraph {
return GetRelationGraph(etype)->OutEdges(0, vids); return GetRelationGraph(etype)->OutEdges(0, vids);
} }
EdgeArray Edges(dgl_type_t etype, const std::string &order = "") const override { EdgeArray Edges(
dgl_type_t etype, const std::string& order = "") const override {
return GetRelationGraph(etype)->Edges(0, order); return GetRelationGraph(etype)->Edges(0, order);
} }
...@@ -180,7 +177,7 @@ class HeteroGraph : public BaseHeteroGraph { ...@@ -180,7 +177,7 @@ class HeteroGraph : public BaseHeteroGraph {
} }
std::vector<IdArray> GetAdj( std::vector<IdArray> GetAdj(
dgl_type_t etype, bool transpose, const std::string &fmt) const override { dgl_type_t etype, bool transpose, const std::string& fmt) const override {
return GetRelationGraph(etype)->GetAdj(0, transpose, fmt); return GetRelationGraph(etype)->GetAdj(0, transpose, fmt);
} }
...@@ -196,7 +193,8 @@ class HeteroGraph : public BaseHeteroGraph { ...@@ -196,7 +193,8 @@ class HeteroGraph : public BaseHeteroGraph {
return GetRelationGraph(etype)->GetCSRMatrix(0); return GetRelationGraph(etype)->GetCSRMatrix(0);
} }
SparseFormat SelectFormat(dgl_type_t etype, dgl_format_code_t preferred_formats) const override { SparseFormat SelectFormat(
dgl_type_t etype, dgl_format_code_t preferred_formats) const override {
return GetRelationGraph(etype)->SelectFormat(0, preferred_formats); return GetRelationGraph(etype)->SelectFormat(0, preferred_formats);
} }
...@@ -208,14 +206,17 @@ class HeteroGraph : public BaseHeteroGraph { ...@@ -208,14 +206,17 @@ class HeteroGraph : public BaseHeteroGraph {
return GetRelationGraph(0)->GetCreatedFormats(); return GetRelationGraph(0)->GetCreatedFormats();
} }
HeteroSubgraph VertexSubgraph(const std::vector<IdArray>& vids) const override; HeteroSubgraph VertexSubgraph(
const std::vector<IdArray>& vids) const override;
HeteroSubgraph EdgeSubgraph( HeteroSubgraph EdgeSubgraph(
const std::vector<IdArray>& eids, bool preserve_nodes = false) const override; const std::vector<IdArray>& eids,
bool preserve_nodes = false) const override;
HeteroGraphPtr GetGraphInFormat(dgl_format_code_t formats) const override; HeteroGraphPtr GetGraphInFormat(dgl_format_code_t formats) const override;
FlattenedHeteroGraphPtr Flatten(const std::vector<dgl_type_t>& etypes) const override; FlattenedHeteroGraphPtr Flatten(
const std::vector<dgl_type_t>& etypes) const override;
GraphPtr AsImmutableGraph() const override; GraphPtr AsImmutableGraph() const override;
...@@ -229,26 +230,23 @@ class HeteroGraph : public BaseHeteroGraph { ...@@ -229,26 +230,23 @@ class HeteroGraph : public BaseHeteroGraph {
static HeteroGraphPtr AsNumBits(HeteroGraphPtr g, uint8_t bits); static HeteroGraphPtr AsNumBits(HeteroGraphPtr g, uint8_t bits);
/** @brief Copy the data to another context */ /** @brief Copy the data to another context */
static HeteroGraphPtr CopyTo(HeteroGraphPtr g, const DGLContext &ctx); static HeteroGraphPtr CopyTo(HeteroGraphPtr g, const DGLContext& ctx);
/** /**
* @brief Pin all relation graphs of the current graph. * @brief Pin all relation graphs of the current graph.
* @note The graph will be pinned inplace. Behavior depends on the current context, * @note The graph will be pinned inplace. Behavior depends on the current
* kDGLCPU: will be pinned; * context, kDGLCPU: will be pinned; IsPinned: directly return; kDGLCUDA:
* IsPinned: directly return; * invalid, will throw an error. The context check is deferred to pinning the
* kDGLCUDA: invalid, will throw an error. * NDArray.
* The context check is deferred to pinning the NDArray. */
*/
void PinMemory_() override; void PinMemory_() override;
/** /**
* @brief Unpin all relation graphs of the current graph. * @brief Unpin all relation graphs of the current graph.
* @note The graph will be unpinned inplace. Behavior depends on the current context, * @note The graph will be unpinned inplace. Behavior depends on the current
* IsPinned: will be unpinned; * context, IsPinned: will be unpinned; others: directly return. The context
* others: directly return. * check is deferred to unpinning the NDArray.
* The context check is deferred to unpinning the NDArray. */
*/
void UnpinMemory_(); void UnpinMemory_();
/** /**
...@@ -258,18 +256,22 @@ class HeteroGraph : public BaseHeteroGraph { ...@@ -258,18 +256,22 @@ class HeteroGraph : public BaseHeteroGraph {
void RecordStream(DGLStreamHandle stream) override; void RecordStream(DGLStreamHandle stream) override;
/** @brief Copy the data to shared memory. /** @brief Copy the data to shared memory.
* *
* Also save names of node types and edge types of the HeteroGraph object to shared memory * Also save names of node types and edge types of the HeteroGraph object to
*/ * shared memory
*/
static HeteroGraphPtr CopyToSharedMem( static HeteroGraphPtr CopyToSharedMem(
HeteroGraphPtr g, const std::string& name, const std::vector<std::string>& ntypes, HeteroGraphPtr g, const std::string& name,
const std::vector<std::string>& etypes, const std::set<std::string>& fmts); const std::vector<std::string>& ntypes,
const std::vector<std::string>& etypes,
const std::set<std::string>& fmts);
/** @brief Create a heterograph from /** @brief Create a heterograph from
* \return the HeteroGraphPtr, names of node types, names of edge types * \return the HeteroGraphPtr, names of node types, names of edge types
*/ */
static std::tuple<HeteroGraphPtr, std::vector<std::string>, std::vector<std::string>> static std::tuple<
CreateFromSharedMem(const std::string &name); HeteroGraphPtr, std::vector<std::string>, std::vector<std::string>>
CreateFromSharedMem(const std::string& name);
/** @brief Creat a LineGraph of self */ /** @brief Creat a LineGraph of self */
HeteroGraphPtr LineGraph(bool backtracking) const; HeteroGraphPtr LineGraph(bool backtracking) const;
...@@ -294,25 +296,25 @@ class HeteroGraph : public BaseHeteroGraph { ...@@ -294,25 +296,25 @@ class HeteroGraph : public BaseHeteroGraph {
/** @brief The shared memory object for meta info*/ /** @brief The shared memory object for meta info*/
std::shared_ptr<runtime::SharedMemory> shared_mem_; std::shared_ptr<runtime::SharedMemory> shared_mem_;
/** @brief The name of the shared memory. Return empty string if it is not in shared memory. */ /** @brief The name of the shared memory. Return empty string if it is not in
* shared memory. */
std::string SharedMemName() const; std::string SharedMemName() const;
/** @brief template class for Flatten operation /** @brief template class for Flatten operation
* *
* @tparam IdType Graph's index data type, can be int32_t or int64_t * @tparam IdType Graph's index data type, can be int32_t or int64_t
* @param etypes vector of etypes to be falttened * @param etypes vector of etypes to be falttened
* @return pointer of FlattenedHeteroGraphh * @return pointer of FlattenedHeteroGraphh
*/ */
template <class IdType> template <class IdType>
FlattenedHeteroGraphPtr FlattenImpl(const std::vector<dgl_type_t>& etypes) const; FlattenedHeteroGraphPtr FlattenImpl(
const std::vector<dgl_type_t>& etypes) const;
}; };
} // namespace dgl } // namespace dgl
namespace dmlc { namespace dmlc {
DMLC_DECLARE_TRAITS(has_saveload, dgl::HeteroGraph, true); DMLC_DECLARE_TRAITS(has_saveload, dgl::HeteroGraph, true);
} // namespace dmlc } // namespace dmlc
#endif // DGL_GRAPH_HETEROGRAPH_H_ #endif // DGL_GRAPH_HETEROGRAPH_H_
...@@ -4,11 +4,12 @@ ...@@ -4,11 +4,12 @@
* @brief Definition of neighborhood-based sampler APIs. * @brief Definition of neighborhood-based sampler APIs.
*/ */
#include <dgl/runtime/container.h>
#include <dgl/packed_func_ext.h>
#include <dgl/array.h> #include <dgl/array.h>
#include <dgl/aten/macro.h> #include <dgl/aten/macro.h>
#include <dgl/packed_func_ext.h>
#include <dgl/runtime/container.h>
#include <dgl/sampling/neighbor.h> #include <dgl/sampling/neighbor.h>
#include "../../../c_api_common.h" #include "../../../c_api_common.h"
#include "../../unit_graph.h" #include "../../unit_graph.h"
...@@ -19,65 +20,60 @@ namespace dgl { ...@@ -19,65 +20,60 @@ namespace dgl {
namespace sampling { namespace sampling {
HeteroSubgraph ExcludeCertainEdges( HeteroSubgraph ExcludeCertainEdges(
const HeteroSubgraph& sg, const HeteroSubgraph& sg, const std::vector<IdArray>& exclude_edges) {
const std::vector<IdArray>& exclude_edges) { HeteroGraphPtr hg_view = HeteroGraphRef(sg.graph).sptr();
std::vector<IdArray> remain_induced_edges(hg_view->NumEdgeTypes());
HeteroGraphPtr hg_view = HeteroGraphRef(sg.graph).sptr(); std::vector<IdArray> remain_edges(hg_view->NumEdgeTypes());
std::vector<IdArray> remain_induced_edges(hg_view->NumEdgeTypes());
std::vector<IdArray> remain_edges(hg_view->NumEdgeTypes()); for (dgl_type_t etype = 0; etype < hg_view->NumEdgeTypes(); ++etype) {
IdArray edge_ids = Range(
for (dgl_type_t etype = 0; etype < hg_view->NumEdgeTypes(); ++etype) { 0, sg.induced_edges[etype]->shape[0],
IdArray edge_ids = Range(0, sg.induced_edges[etype]->dtype.bits, sg.induced_edges[etype]->ctx);
sg.induced_edges[etype]->shape[0], if (exclude_edges[etype].GetSize() == 0 || edge_ids.GetSize() == 0) {
sg.induced_edges[etype]->dtype.bits, remain_edges[etype] = edge_ids;
sg.induced_edges[etype]->ctx); remain_induced_edges[etype] = sg.induced_edges[etype];
if (exclude_edges[etype].GetSize() == 0 || edge_ids.GetSize() == 0) { continue;
remain_edges[etype] = edge_ids;
remain_induced_edges[etype] = sg.induced_edges[etype];
continue;
}
ATEN_ID_TYPE_SWITCH(hg_view->DataType(), IdType, {
IdType* idx_data = edge_ids.Ptr<IdType>();
IdType* induced_edges_data = sg.induced_edges[etype].Ptr<IdType>();
const IdType exclude_edges_len = exclude_edges[etype]->shape[0];
std::sort(exclude_edges[etype].Ptr<IdType>(),
exclude_edges[etype].Ptr<IdType>() + exclude_edges_len);
const IdType* exclude_edges_data = exclude_edges[etype].Ptr<IdType>();
IdType outId = 0;
for (IdType i = 0; i != sg.induced_edges[etype]->shape[0]; ++i) {
if (!std::binary_search(exclude_edges_data,
exclude_edges_data + exclude_edges_len,
induced_edges_data[i])) {
induced_edges_data[outId] = induced_edges_data[i];
idx_data[outId] = idx_data[i];
++outId;
}
}
remain_edges[etype] = aten::IndexSelect(edge_ids, 0, outId);
remain_induced_edges[etype] = aten::IndexSelect(sg.induced_edges[etype], 0, outId);
});
} }
HeteroSubgraph subg = hg_view->EdgeSubgraph(remain_edges, true); ATEN_ID_TYPE_SWITCH(hg_view->DataType(), IdType, {
subg.induced_edges = std::move(remain_induced_edges); IdType* idx_data = edge_ids.Ptr<IdType>();
return subg; IdType* induced_edges_data = sg.induced_edges[etype].Ptr<IdType>();
const IdType exclude_edges_len = exclude_edges[etype]->shape[0];
std::sort(
exclude_edges[etype].Ptr<IdType>(),
exclude_edges[etype].Ptr<IdType>() + exclude_edges_len);
const IdType* exclude_edges_data = exclude_edges[etype].Ptr<IdType>();
IdType outId = 0;
for (IdType i = 0; i != sg.induced_edges[etype]->shape[0]; ++i) {
if (!std::binary_search(
exclude_edges_data, exclude_edges_data + exclude_edges_len,
induced_edges_data[i])) {
induced_edges_data[outId] = induced_edges_data[i];
idx_data[outId] = idx_data[i];
++outId;
}
}
remain_edges[etype] = aten::IndexSelect(edge_ids, 0, outId);
remain_induced_edges[etype] =
aten::IndexSelect(sg.induced_edges[etype], 0, outId);
});
}
HeteroSubgraph subg = hg_view->EdgeSubgraph(remain_edges, true);
subg.induced_edges = std::move(remain_induced_edges);
return subg;
} }
HeteroSubgraph SampleNeighbors( HeteroSubgraph SampleNeighbors(
const HeteroGraphPtr hg, const HeteroGraphPtr hg, const std::vector<IdArray>& nodes,
const std::vector<IdArray>& nodes, const std::vector<int64_t>& fanouts, EdgeDir dir,
const std::vector<int64_t>& fanouts,
EdgeDir dir,
const std::vector<NDArray>& prob_or_mask, const std::vector<NDArray>& prob_or_mask,
const std::vector<IdArray>& exclude_edges, const std::vector<IdArray>& exclude_edges, bool replace) {
bool replace) {
// sanity check // sanity check
CHECK_EQ(nodes.size(), hg->NumVertexTypes()) CHECK_EQ(nodes.size(), hg->NumVertexTypes())
<< "Number of node ID tensors must match the number of node types."; << "Number of node ID tensors must match the number of node types.";
CHECK_EQ(fanouts.size(), hg->NumEdgeTypes()) CHECK_EQ(fanouts.size(), hg->NumEdgeTypes())
<< "Number of fanout values must match the number of edge types."; << "Number of fanout values must match the number of edge types.";
CHECK_EQ(prob_or_mask.size(), hg->NumEdgeTypes()) CHECK_EQ(prob_or_mask.size(), hg->NumEdgeTypes())
<< "Number of probability tensors must match the number of edge types."; << "Number of probability tensors must match the number of edge types.";
DGLContext ctx = aten::GetContextOf(nodes); DGLContext ctx = aten::GetContextOf(nodes);
...@@ -87,42 +83,46 @@ HeteroSubgraph SampleNeighbors( ...@@ -87,42 +83,46 @@ HeteroSubgraph SampleNeighbors(
auto pair = hg->meta_graph()->FindEdge(etype); auto pair = hg->meta_graph()->FindEdge(etype);
const dgl_type_t src_vtype = pair.first; const dgl_type_t src_vtype = pair.first;
const dgl_type_t dst_vtype = pair.second; const dgl_type_t dst_vtype = pair.second;
const IdArray nodes_ntype = nodes[(dir == EdgeDir::kOut)? src_vtype : dst_vtype]; const IdArray nodes_ntype =
nodes[(dir == EdgeDir::kOut) ? src_vtype : dst_vtype];
const int64_t num_nodes = nodes_ntype->shape[0]; const int64_t num_nodes = nodes_ntype->shape[0];
if (num_nodes == 0 || fanouts[etype] == 0) { if (num_nodes == 0 || fanouts[etype] == 0) {
// Nothing to sample for this etype, create a placeholder relation graph // Nothing to sample for this etype, create a placeholder relation graph
subrels[etype] = UnitGraph::Empty( subrels[etype] = UnitGraph::Empty(
hg->GetRelationGraph(etype)->NumVertexTypes(), hg->GetRelationGraph(etype)->NumVertexTypes(),
hg->NumVertices(src_vtype), hg->NumVertices(src_vtype), hg->NumVertices(dst_vtype),
hg->NumVertices(dst_vtype), hg->DataType(), ctx);
hg->DataType(), ctx);
induced_edges[etype] = aten::NullArray(hg->DataType(), ctx); induced_edges[etype] = aten::NullArray(hg->DataType(), ctx);
} else { } else {
COOMatrix sampled_coo; COOMatrix sampled_coo;
// sample from one relation graph // sample from one relation graph
auto req_fmt = (dir == EdgeDir::kOut)? CSR_CODE : CSC_CODE; auto req_fmt = (dir == EdgeDir::kOut) ? CSR_CODE : CSC_CODE;
auto avail_fmt = hg->SelectFormat(etype, req_fmt); auto avail_fmt = hg->SelectFormat(etype, req_fmt);
switch (avail_fmt) { switch (avail_fmt) {
case SparseFormat::kCOO: case SparseFormat::kCOO:
if (dir == EdgeDir::kIn) { if (dir == EdgeDir::kIn) {
sampled_coo = aten::COOTranspose(aten::COORowWiseSampling( sampled_coo = aten::COOTranspose(aten::COORowWiseSampling(
aten::COOTranspose(hg->GetCOOMatrix(etype)), aten::COOTranspose(hg->GetCOOMatrix(etype)), nodes_ntype,
nodes_ntype, fanouts[etype], prob_or_mask[etype], replace)); fanouts[etype], prob_or_mask[etype], replace));
} else { } else {
sampled_coo = aten::COORowWiseSampling( sampled_coo = aten::COORowWiseSampling(
hg->GetCOOMatrix(etype), nodes_ntype, fanouts[etype], prob_or_mask[etype], replace); hg->GetCOOMatrix(etype), nodes_ntype, fanouts[etype],
prob_or_mask[etype], replace);
} }
break; break;
case SparseFormat::kCSR: case SparseFormat::kCSR:
CHECK(dir == EdgeDir::kOut) << "Cannot sample out edges on CSC matrix."; CHECK(dir == EdgeDir::kOut)
<< "Cannot sample out edges on CSC matrix.";
sampled_coo = aten::CSRRowWiseSampling( sampled_coo = aten::CSRRowWiseSampling(
hg->GetCSRMatrix(etype), nodes_ntype, fanouts[etype], prob_or_mask[etype], replace); hg->GetCSRMatrix(etype), nodes_ntype, fanouts[etype],
prob_or_mask[etype], replace);
break; break;
case SparseFormat::kCSC: case SparseFormat::kCSC:
CHECK(dir == EdgeDir::kIn) << "Cannot sample in edges on CSR matrix."; CHECK(dir == EdgeDir::kIn) << "Cannot sample in edges on CSR matrix.";
sampled_coo = aten::CSRRowWiseSampling( sampled_coo = aten::CSRRowWiseSampling(
hg->GetCSCMatrix(etype), nodes_ntype, fanouts[etype], prob_or_mask[etype], replace); hg->GetCSCMatrix(etype), nodes_ntype, fanouts[etype],
prob_or_mask[etype], replace);
sampled_coo = aten::COOTranspose(sampled_coo); sampled_coo = aten::COOTranspose(sampled_coo);
break; break;
default: default:
...@@ -130,14 +130,15 @@ HeteroSubgraph SampleNeighbors( ...@@ -130,14 +130,15 @@ HeteroSubgraph SampleNeighbors(
} }
subrels[etype] = UnitGraph::CreateFromCOO( subrels[etype] = UnitGraph::CreateFromCOO(
hg->GetRelationGraph(etype)->NumVertexTypes(), sampled_coo.num_rows, sampled_coo.num_cols, hg->GetRelationGraph(etype)->NumVertexTypes(), sampled_coo.num_rows,
sampled_coo.row, sampled_coo.col); sampled_coo.num_cols, sampled_coo.row, sampled_coo.col);
induced_edges[etype] = sampled_coo.data; induced_edges[etype] = sampled_coo.data;
} }
} }
HeteroSubgraph ret; HeteroSubgraph ret;
ret.graph = CreateHeteroGraph(hg->meta_graph(), subrels, hg->NumVerticesPerType()); ret.graph =
CreateHeteroGraph(hg->meta_graph(), subrels, hg->NumVerticesPerType());
ret.induced_vertices.resize(hg->NumVertexTypes()); ret.induced_vertices.resize(hg->NumVertexTypes());
ret.induced_edges = std::move(induced_edges); ret.induced_edges = std::move(induced_edges);
if (!exclude_edges.empty()) { if (!exclude_edges.empty()) {
...@@ -147,19 +148,15 @@ HeteroSubgraph SampleNeighbors( ...@@ -147,19 +148,15 @@ HeteroSubgraph SampleNeighbors(
} }
HeteroSubgraph SampleNeighborsEType( HeteroSubgraph SampleNeighborsEType(
const HeteroGraphPtr hg, const HeteroGraphPtr hg, const IdArray nodes,
const IdArray nodes,
const std::vector<int64_t>& eid2etype_offset, const std::vector<int64_t>& eid2etype_offset,
const std::vector<int64_t>& fanouts, const std::vector<int64_t>& fanouts, EdgeDir dir,
EdgeDir dir, const std::vector<FloatArray>& prob, bool replace,
const std::vector<FloatArray>& prob,
bool replace,
bool rowwise_etype_sorted) { bool rowwise_etype_sorted) {
CHECK_EQ(1, hg->NumVertexTypes()) CHECK_EQ(1, hg->NumVertexTypes())
<< "SampleNeighborsEType only work with homogeneous graph"; << "SampleNeighborsEType only work with homogeneous graph";
CHECK_EQ(1, hg->NumEdgeTypes()) CHECK_EQ(1, hg->NumEdgeTypes())
<< "SampleNeighborsEType only work with homogeneous graph"; << "SampleNeighborsEType only work with homogeneous graph";
std::vector<HeteroGraphPtr> subrels(1); std::vector<HeteroGraphPtr> subrels(1);
std::vector<IdArray> induced_edges(1); std::vector<IdArray> induced_edges(1);
...@@ -178,39 +175,39 @@ HeteroSubgraph SampleNeighborsEType( ...@@ -178,39 +175,39 @@ HeteroSubgraph SampleNeighborsEType(
} }
if (num_nodes == 0 || (same_fanout && fanout_value == 0)) { if (num_nodes == 0 || (same_fanout && fanout_value == 0)) {
subrels[etype] = UnitGraph::Empty(1, subrels[etype] = UnitGraph::Empty(
hg->NumVertices(src_vtype), 1, hg->NumVertices(src_vtype), hg->NumVertices(dst_vtype),
hg->NumVertices(dst_vtype), hg->DataType(), hg->Context());
hg->DataType(), hg->Context());
induced_edges[etype] = aten::NullArray(); induced_edges[etype] = aten::NullArray();
} else { } else {
COOMatrix sampled_coo; COOMatrix sampled_coo;
// sample from graph // sample from graph
// the edge type is stored in etypes // the edge type is stored in etypes
auto req_fmt = (dir == EdgeDir::kOut)? CSR_CODE : CSC_CODE; auto req_fmt = (dir == EdgeDir::kOut) ? CSR_CODE : CSC_CODE;
auto avail_fmt = hg->SelectFormat(etype, req_fmt); auto avail_fmt = hg->SelectFormat(etype, req_fmt);
switch (avail_fmt) { switch (avail_fmt) {
case SparseFormat::kCOO: case SparseFormat::kCOO:
if (dir == EdgeDir::kIn) { if (dir == EdgeDir::kIn) {
sampled_coo = aten::COOTranspose(aten::COORowWisePerEtypeSampling( sampled_coo = aten::COOTranspose(aten::COORowWisePerEtypeSampling(
aten::COOTranspose(hg->GetCOOMatrix(etype)), aten::COOTranspose(hg->GetCOOMatrix(etype)), nodes,
nodes, eid2etype_offset, fanouts, prob, replace)); eid2etype_offset, fanouts, prob, replace));
} else { } else {
sampled_coo = aten::COORowWisePerEtypeSampling( sampled_coo = aten::COORowWisePerEtypeSampling(
hg->GetCOOMatrix(etype), nodes, eid2etype_offset, fanouts, prob, replace); hg->GetCOOMatrix(etype), nodes, eid2etype_offset, fanouts, prob,
replace);
} }
break; break;
case SparseFormat::kCSR: case SparseFormat::kCSR:
CHECK(dir == EdgeDir::kOut) << "Cannot sample out edges on CSC matrix."; CHECK(dir == EdgeDir::kOut) << "Cannot sample out edges on CSC matrix.";
sampled_coo = aten::CSRRowWisePerEtypeSampling( sampled_coo = aten::CSRRowWisePerEtypeSampling(
hg->GetCSRMatrix(etype), nodes, eid2etype_offset, hg->GetCSRMatrix(etype), nodes, eid2etype_offset, fanouts, prob,
fanouts, prob, replace, rowwise_etype_sorted); replace, rowwise_etype_sorted);
break; break;
case SparseFormat::kCSC: case SparseFormat::kCSC:
CHECK(dir == EdgeDir::kIn) << "Cannot sample in edges on CSR matrix."; CHECK(dir == EdgeDir::kIn) << "Cannot sample in edges on CSR matrix.";
sampled_coo = aten::CSRRowWisePerEtypeSampling( sampled_coo = aten::CSRRowWisePerEtypeSampling(
hg->GetCSCMatrix(etype), nodes, eid2etype_offset, hg->GetCSCMatrix(etype), nodes, eid2etype_offset, fanouts, prob,
fanouts, prob, replace, rowwise_etype_sorted); replace, rowwise_etype_sorted);
sampled_coo = aten::COOTranspose(sampled_coo); sampled_coo = aten::COOTranspose(sampled_coo);
break; break;
default: default:
...@@ -218,32 +215,30 @@ HeteroSubgraph SampleNeighborsEType( ...@@ -218,32 +215,30 @@ HeteroSubgraph SampleNeighborsEType(
} }
subrels[etype] = UnitGraph::CreateFromCOO( subrels[etype] = UnitGraph::CreateFromCOO(
1, sampled_coo.num_rows, sampled_coo.num_cols, 1, sampled_coo.num_rows, sampled_coo.num_cols, sampled_coo.row,
sampled_coo.row, sampled_coo.col); sampled_coo.col);
induced_edges[etype] = sampled_coo.data; induced_edges[etype] = sampled_coo.data;
} }
HeteroSubgraph ret; HeteroSubgraph ret;
ret.graph = CreateHeteroGraph(hg->meta_graph(), subrels, hg->NumVerticesPerType()); ret.graph =
CreateHeteroGraph(hg->meta_graph(), subrels, hg->NumVerticesPerType());
ret.induced_vertices.resize(hg->NumVertexTypes()); ret.induced_vertices.resize(hg->NumVertexTypes());
ret.induced_edges = std::move(induced_edges); ret.induced_edges = std::move(induced_edges);
return ret; return ret;
} }
HeteroSubgraph SampleNeighborsTopk( HeteroSubgraph SampleNeighborsTopk(
const HeteroGraphPtr hg, const HeteroGraphPtr hg, const std::vector<IdArray>& nodes,
const std::vector<IdArray>& nodes, const std::vector<int64_t>& k, EdgeDir dir,
const std::vector<int64_t>& k, const std::vector<FloatArray>& weight, bool ascending) {
EdgeDir dir,
const std::vector<FloatArray>& weight,
bool ascending) {
// sanity check // sanity check
CHECK_EQ(nodes.size(), hg->NumVertexTypes()) CHECK_EQ(nodes.size(), hg->NumVertexTypes())
<< "Number of node ID tensors must match the number of node types."; << "Number of node ID tensors must match the number of node types.";
CHECK_EQ(k.size(), hg->NumEdgeTypes()) CHECK_EQ(k.size(), hg->NumEdgeTypes())
<< "Number of k values must match the number of edge types."; << "Number of k values must match the number of edge types.";
CHECK_EQ(weight.size(), hg->NumEdgeTypes()) CHECK_EQ(weight.size(), hg->NumEdgeTypes())
<< "Number of weight tensors must match the number of edge types."; << "Number of weight tensors must match the number of edge types.";
std::vector<HeteroGraphPtr> subrels(hg->NumEdgeTypes()); std::vector<HeteroGraphPtr> subrels(hg->NumEdgeTypes());
std::vector<IdArray> induced_edges(hg->NumEdgeTypes()); std::vector<IdArray> induced_edges(hg->NumEdgeTypes());
...@@ -251,214 +246,218 @@ HeteroSubgraph SampleNeighborsTopk( ...@@ -251,214 +246,218 @@ HeteroSubgraph SampleNeighborsTopk(
auto pair = hg->meta_graph()->FindEdge(etype); auto pair = hg->meta_graph()->FindEdge(etype);
const dgl_type_t src_vtype = pair.first; const dgl_type_t src_vtype = pair.first;
const dgl_type_t dst_vtype = pair.second; const dgl_type_t dst_vtype = pair.second;
const IdArray nodes_ntype = nodes[(dir == EdgeDir::kOut)? src_vtype : dst_vtype]; const IdArray nodes_ntype =
nodes[(dir == EdgeDir::kOut) ? src_vtype : dst_vtype];
const int64_t num_nodes = nodes_ntype->shape[0]; const int64_t num_nodes = nodes_ntype->shape[0];
if (num_nodes == 0 || k[etype] == 0) { if (num_nodes == 0 || k[etype] == 0) {
// Nothing to sample for this etype, create a placeholder relation graph // Nothing to sample for this etype, create a placeholder relation graph
subrels[etype] = UnitGraph::Empty( subrels[etype] = UnitGraph::Empty(
hg->GetRelationGraph(etype)->NumVertexTypes(), hg->GetRelationGraph(etype)->NumVertexTypes(),
hg->NumVertices(src_vtype), hg->NumVertices(src_vtype), hg->NumVertices(dst_vtype),
hg->NumVertices(dst_vtype), hg->DataType(), hg->Context());
hg->DataType(), hg->Context());
induced_edges[etype] = aten::NullArray(); induced_edges[etype] = aten::NullArray();
} else { } else {
// sample from one relation graph // sample from one relation graph
auto req_fmt = (dir == EdgeDir::kOut)? CSR_CODE : CSC_CODE; auto req_fmt = (dir == EdgeDir::kOut) ? CSR_CODE : CSC_CODE;
auto avail_fmt = hg->SelectFormat(etype, req_fmt); auto avail_fmt = hg->SelectFormat(etype, req_fmt);
COOMatrix sampled_coo; COOMatrix sampled_coo;
switch (avail_fmt) { switch (avail_fmt) {
case SparseFormat::kCOO: case SparseFormat::kCOO:
if (dir == EdgeDir::kIn) { if (dir == EdgeDir::kIn) {
sampled_coo = aten::COOTranspose(aten::COORowWiseTopk( sampled_coo = aten::COOTranspose(aten::COORowWiseTopk(
aten::COOTranspose(hg->GetCOOMatrix(etype)), aten::COOTranspose(hg->GetCOOMatrix(etype)), nodes_ntype,
nodes_ntype, k[etype], weight[etype], ascending)); k[etype], weight[etype], ascending));
} else { } else {
sampled_coo = aten::COORowWiseTopk( sampled_coo = aten::COORowWiseTopk(
hg->GetCOOMatrix(etype), nodes_ntype, k[etype], weight[etype], ascending); hg->GetCOOMatrix(etype), nodes_ntype, k[etype], weight[etype],
ascending);
} }
break; break;
case SparseFormat::kCSR: case SparseFormat::kCSR:
CHECK(dir == EdgeDir::kOut) << "Cannot sample out edges on CSC matrix."; CHECK(dir == EdgeDir::kOut)
<< "Cannot sample out edges on CSC matrix.";
sampled_coo = aten::CSRRowWiseTopk( sampled_coo = aten::CSRRowWiseTopk(
hg->GetCSRMatrix(etype), nodes_ntype, k[etype], weight[etype], ascending); hg->GetCSRMatrix(etype), nodes_ntype, k[etype], weight[etype],
ascending);
break; break;
case SparseFormat::kCSC: case SparseFormat::kCSC:
CHECK(dir == EdgeDir::kIn) << "Cannot sample in edges on CSR matrix."; CHECK(dir == EdgeDir::kIn) << "Cannot sample in edges on CSR matrix.";
sampled_coo = aten::CSRRowWiseTopk( sampled_coo = aten::CSRRowWiseTopk(
hg->GetCSCMatrix(etype), nodes_ntype, k[etype], weight[etype], ascending); hg->GetCSCMatrix(etype), nodes_ntype, k[etype], weight[etype],
ascending);
sampled_coo = aten::COOTranspose(sampled_coo); sampled_coo = aten::COOTranspose(sampled_coo);
break; break;
default: default:
LOG(FATAL) << "Unsupported sparse format."; LOG(FATAL) << "Unsupported sparse format.";
} }
subrels[etype] = UnitGraph::CreateFromCOO( subrels[etype] = UnitGraph::CreateFromCOO(
hg->GetRelationGraph(etype)->NumVertexTypes(), sampled_coo.num_rows, sampled_coo.num_cols, hg->GetRelationGraph(etype)->NumVertexTypes(), sampled_coo.num_rows,
sampled_coo.row, sampled_coo.col); sampled_coo.num_cols, sampled_coo.row, sampled_coo.col);
induced_edges[etype] = sampled_coo.data; induced_edges[etype] = sampled_coo.data;
} }
} }
HeteroSubgraph ret; HeteroSubgraph ret;
ret.graph = CreateHeteroGraph(hg->meta_graph(), subrels, hg->NumVerticesPerType()); ret.graph =
CreateHeteroGraph(hg->meta_graph(), subrels, hg->NumVerticesPerType());
ret.induced_vertices.resize(hg->NumVertexTypes()); ret.induced_vertices.resize(hg->NumVertexTypes());
ret.induced_edges = std::move(induced_edges); ret.induced_edges = std::move(induced_edges);
return ret; return ret;
} }
HeteroSubgraph SampleNeighborsBiased( HeteroSubgraph SampleNeighborsBiased(
const HeteroGraphPtr hg, const HeteroGraphPtr hg, const IdArray& nodes, const int64_t fanout,
const IdArray& nodes, const NDArray& bias, const NDArray& tag_offset, const EdgeDir dir,
const int64_t fanout, const bool replace) {
const NDArray& bias, CHECK_EQ(hg->NumEdgeTypes(), 1)
const NDArray& tag_offset, << "Only homogeneous or bipartite graphs are supported";
const EdgeDir dir,
const bool replace
) {
CHECK_EQ(hg->NumEdgeTypes(), 1) << "Only homogeneous or bipartite graphs are supported";
auto pair = hg->meta_graph()->FindEdge(0); auto pair = hg->meta_graph()->FindEdge(0);
const dgl_type_t src_vtype = pair.first; const dgl_type_t src_vtype = pair.first;
const dgl_type_t dst_vtype = pair.second; const dgl_type_t dst_vtype = pair.second;
const dgl_type_t nodes_ntype = (dir == EdgeDir::kOut) ? src_vtype : dst_vtype; const dgl_type_t nodes_ntype = (dir == EdgeDir::kOut) ? src_vtype : dst_vtype;
// sanity check // sanity check
CHECK_EQ(tag_offset->ndim, 2) << "The shape of tag_offset should be [num_nodes, num_tags + 1]"; CHECK_EQ(tag_offset->ndim, 2)
<< "The shape of tag_offset should be [num_nodes, num_tags + 1]";
CHECK_EQ(tag_offset->shape[0], hg->NumVertices(nodes_ntype)) CHECK_EQ(tag_offset->shape[0], hg->NumVertices(nodes_ntype))
<< "The shape of tag_offset should be [num_nodes, num_tags + 1]"; << "The shape of tag_offset should be [num_nodes, num_tags + 1]";
CHECK_EQ(tag_offset->shape[1], bias->shape[0] + 1) CHECK_EQ(tag_offset->shape[1], bias->shape[0] + 1)
<< "The sizes of tag_offset and bias are inconsistent"; << "The sizes of tag_offset and bias are inconsistent";
const int64_t num_nodes = nodes->shape[0]; const int64_t num_nodes = nodes->shape[0];
HeteroGraphPtr subrel; HeteroGraphPtr subrel;
IdArray induced_edges; IdArray induced_edges;
const dgl_type_t etype = 0; const dgl_type_t etype = 0;
if (num_nodes == 0 || fanout == 0) { if (num_nodes == 0 || fanout == 0) {
// Nothing to sample for this etype, create a placeholder relation graph // Nothing to sample for this etype, create a placeholder relation graph
subrel = UnitGraph::Empty( subrel = UnitGraph::Empty(
hg->GetRelationGraph(etype)->NumVertexTypes(), hg->GetRelationGraph(etype)->NumVertexTypes(),
hg->NumVertices(src_vtype), hg->NumVertices(src_vtype), hg->NumVertices(dst_vtype), hg->DataType(),
hg->NumVertices(dst_vtype), hg->Context());
hg->DataType(), hg->Context()); induced_edges = aten::NullArray();
induced_edges = aten::NullArray(); } else {
} else { // sample from one relation graph
// sample from one relation graph const auto req_fmt = (dir == EdgeDir::kOut) ? CSR_CODE : CSC_CODE;
const auto req_fmt = (dir == EdgeDir::kOut)? CSR_CODE : CSC_CODE; const auto created_fmt = hg->GetCreatedFormats();
const auto created_fmt = hg->GetCreatedFormats(); COOMatrix sampled_coo;
COOMatrix sampled_coo;
switch (req_fmt) { switch (req_fmt) {
case CSR_CODE: case CSR_CODE:
CHECK(created_fmt & CSR_CODE) << "A sorted CSR Matrix is required."; CHECK(created_fmt & CSR_CODE) << "A sorted CSR Matrix is required.";
sampled_coo = aten::CSRRowWiseSamplingBiased( sampled_coo = aten::CSRRowWiseSamplingBiased(
hg->GetCSRMatrix(etype), nodes, fanout, tag_offset, bias, replace); hg->GetCSRMatrix(etype), nodes, fanout, tag_offset, bias, replace);
break; break;
case CSC_CODE: case CSC_CODE:
CHECK(created_fmt & CSC_CODE) << "A sorted CSC Matrix is required."; CHECK(created_fmt & CSC_CODE) << "A sorted CSC Matrix is required.";
sampled_coo = aten::CSRRowWiseSamplingBiased( sampled_coo = aten::CSRRowWiseSamplingBiased(
hg->GetCSCMatrix(etype), nodes, fanout, tag_offset, bias, replace); hg->GetCSCMatrix(etype), nodes, fanout, tag_offset, bias, replace);
sampled_coo = aten::COOTranspose(sampled_coo); sampled_coo = aten::COOTranspose(sampled_coo);
break; break;
default: default:
LOG(FATAL) << "Unsupported sparse format."; LOG(FATAL) << "Unsupported sparse format.";
}
subrel = UnitGraph::CreateFromCOO(
hg->GetRelationGraph(etype)->NumVertexTypes(), sampled_coo.num_rows, sampled_coo.num_cols,
sampled_coo.row, sampled_coo.col);
induced_edges = sampled_coo.data;
} }
subrel = UnitGraph::CreateFromCOO(
hg->GetRelationGraph(etype)->NumVertexTypes(), sampled_coo.num_rows,
sampled_coo.num_cols, sampled_coo.row, sampled_coo.col);
induced_edges = sampled_coo.data;
}
HeteroSubgraph ret; HeteroSubgraph ret;
ret.graph = CreateHeteroGraph(hg->meta_graph(), {subrel}, hg->NumVerticesPerType()); ret.graph =
CreateHeteroGraph(hg->meta_graph(), {subrel}, hg->NumVerticesPerType());
ret.induced_vertices.resize(hg->NumVertexTypes()); ret.induced_vertices.resize(hg->NumVertexTypes());
ret.induced_edges = {induced_edges}; ret.induced_edges = {induced_edges};
return ret; return ret;
} }
DGL_REGISTER_GLOBAL("sampling.neighbor._CAPI_DGLSampleNeighborsEType") DGL_REGISTER_GLOBAL("sampling.neighbor._CAPI_DGLSampleNeighborsEType")
.set_body([] (DGLArgs args, DGLRetValue *rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
IdArray nodes = args[1]; IdArray nodes = args[1];
const std::vector<int64_t>& eid2etype_offset = ListValueToVector<int64_t>(args[2]); const std::vector<int64_t>& eid2etype_offset =
IdArray fanout = args[3]; ListValueToVector<int64_t>(args[2]);
const std::string dir_str = args[4]; IdArray fanout = args[3];
const auto& prob = ListValueToVector<FloatArray>(args[5]); const std::string dir_str = args[4];
const bool replace = args[6]; const auto& prob = ListValueToVector<FloatArray>(args[5]);
const bool rowwise_etype_sorted = args[7]; const bool replace = args[6];
const bool rowwise_etype_sorted = args[7];
CHECK(dir_str == "in" || dir_str == "out")
<< "Invalid edge direction. Must be \"in\" or \"out\"."; CHECK(dir_str == "in" || dir_str == "out")
EdgeDir dir = (dir_str == "in")? EdgeDir::kIn : EdgeDir::kOut; << "Invalid edge direction. Must be \"in\" or \"out\".";
CHECK_INT64(fanout, "fanout"); EdgeDir dir = (dir_str == "in") ? EdgeDir::kIn : EdgeDir::kOut;
std::vector<int64_t> fanout_vec = fanout.ToVector<int64_t>(); CHECK_INT64(fanout, "fanout");
std::vector<int64_t> fanout_vec = fanout.ToVector<int64_t>();
std::shared_ptr<HeteroSubgraph> subg(new HeteroSubgraph);
*subg = sampling::SampleNeighborsEType( std::shared_ptr<HeteroSubgraph> subg(new HeteroSubgraph);
hg.sptr(), nodes, eid2etype_offset, fanout_vec, dir, prob, replace, rowwise_etype_sorted); *subg = sampling::SampleNeighborsEType(
*rv = HeteroSubgraphRef(subg); hg.sptr(), nodes, eid2etype_offset, fanout_vec, dir, prob, replace,
}); rowwise_etype_sorted);
*rv = HeteroSubgraphRef(subg);
});
DGL_REGISTER_GLOBAL("sampling.neighbor._CAPI_DGLSampleNeighbors") DGL_REGISTER_GLOBAL("sampling.neighbor._CAPI_DGLSampleNeighbors")
.set_body([] (DGLArgs args, DGLRetValue *rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
const auto& nodes = ListValueToVector<IdArray>(args[1]); const auto& nodes = ListValueToVector<IdArray>(args[1]);
IdArray fanouts_array = args[2]; IdArray fanouts_array = args[2];
const auto& fanouts = fanouts_array.ToVector<int64_t>(); const auto& fanouts = fanouts_array.ToVector<int64_t>();
const std::string dir_str = args[3]; const std::string dir_str = args[3];
const auto& prob_or_mask = ListValueToVector<NDArray>(args[4]); const auto& prob_or_mask = ListValueToVector<NDArray>(args[4]);
const auto& exclude_edges = ListValueToVector<IdArray>(args[5]); const auto& exclude_edges = ListValueToVector<IdArray>(args[5]);
const bool replace = args[6]; const bool replace = args[6];
CHECK(dir_str == "in" || dir_str == "out") CHECK(dir_str == "in" || dir_str == "out")
<< "Invalid edge direction. Must be \"in\" or \"out\"."; << "Invalid edge direction. Must be \"in\" or \"out\".";
EdgeDir dir = (dir_str == "in")? EdgeDir::kIn : EdgeDir::kOut; EdgeDir dir = (dir_str == "in") ? EdgeDir::kIn : EdgeDir::kOut;
std::shared_ptr<HeteroSubgraph> subg(new HeteroSubgraph); std::shared_ptr<HeteroSubgraph> subg(new HeteroSubgraph);
*subg = sampling::SampleNeighbors( *subg = sampling::SampleNeighbors(
hg.sptr(), nodes, fanouts, dir, prob_or_mask, exclude_edges, replace); hg.sptr(), nodes, fanouts, dir, prob_or_mask, exclude_edges, replace);
*rv = HeteroSubgraphRef(subg); *rv = HeteroSubgraphRef(subg);
}); });
DGL_REGISTER_GLOBAL("sampling.neighbor._CAPI_DGLSampleNeighborsTopk") DGL_REGISTER_GLOBAL("sampling.neighbor._CAPI_DGLSampleNeighborsTopk")
.set_body([] (DGLArgs args, DGLRetValue *rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
const auto& nodes = ListValueToVector<IdArray>(args[1]); const auto& nodes = ListValueToVector<IdArray>(args[1]);
IdArray k_array = args[2]; IdArray k_array = args[2];
const auto& k = k_array.ToVector<int64_t>(); const auto& k = k_array.ToVector<int64_t>();
const std::string dir_str = args[3]; const std::string dir_str = args[3];
const auto& weight = ListValueToVector<FloatArray>(args[4]); const auto& weight = ListValueToVector<FloatArray>(args[4]);
const bool ascending = args[5]; const bool ascending = args[5];
CHECK(dir_str == "in" || dir_str == "out") CHECK(dir_str == "in" || dir_str == "out")
<< "Invalid edge direction. Must be \"in\" or \"out\"."; << "Invalid edge direction. Must be \"in\" or \"out\".";
EdgeDir dir = (dir_str == "in")? EdgeDir::kIn : EdgeDir::kOut; EdgeDir dir = (dir_str == "in") ? EdgeDir::kIn : EdgeDir::kOut;
std::shared_ptr<HeteroSubgraph> subg(new HeteroSubgraph); std::shared_ptr<HeteroSubgraph> subg(new HeteroSubgraph);
*subg = sampling::SampleNeighborsTopk( *subg = sampling::SampleNeighborsTopk(
hg.sptr(), nodes, k, dir, weight, ascending); hg.sptr(), nodes, k, dir, weight, ascending);
*rv = HeteroGraphRef(subg); *rv = HeteroGraphRef(subg);
}); });
DGL_REGISTER_GLOBAL("sampling.neighbor._CAPI_DGLSampleNeighborsBiased") DGL_REGISTER_GLOBAL("sampling.neighbor._CAPI_DGLSampleNeighborsBiased")
.set_body([] (DGLArgs args, DGLRetValue *rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
const IdArray nodes = args[1]; const IdArray nodes = args[1];
const int64_t fanout = args[2]; const int64_t fanout = args[2];
const NDArray bias = args[3]; const NDArray bias = args[3];
const NDArray tag_offset = args[4]; const NDArray tag_offset = args[4];
const std::string dir_str = args[5]; const std::string dir_str = args[5];
const bool replace = args[6]; const bool replace = args[6];
CHECK(dir_str == "in" || dir_str == "out") CHECK(dir_str == "in" || dir_str == "out")
<< "Invalid edge direction. Must be \"in\" or \"out\"."; << "Invalid edge direction. Must be \"in\" or \"out\".";
EdgeDir dir = (dir_str == "in")? EdgeDir::kIn : EdgeDir::kOut; EdgeDir dir = (dir_str == "in") ? EdgeDir::kIn : EdgeDir::kOut;
std::shared_ptr<HeteroSubgraph> subg(new HeteroSubgraph); std::shared_ptr<HeteroSubgraph> subg(new HeteroSubgraph);
*subg = sampling::SampleNeighborsBiased( *subg = sampling::SampleNeighborsBiased(
hg.sptr(), nodes, fanout, bias, tag_offset, dir, replace); hg.sptr(), nodes, fanout, bias, tag_offset, dir, replace);
*rv = HeteroGraphRef(subg); *rv = HeteroGraphRef(subg);
}); });
} // namespace sampling } // namespace sampling
} // namespace dgl } // namespace dgl
...@@ -65,10 +65,11 @@ DMLC_DECLARE_TRAITS(has_saveload, GraphDataObject, true); ...@@ -65,10 +65,11 @@ DMLC_DECLARE_TRAITS(has_saveload, GraphDataObject, true);
namespace dgl { namespace dgl {
namespace serialize { namespace serialize {
bool SaveDGLGraphs(std::string filename, List<GraphData> graph_data, bool SaveDGLGraphs(
std::vector<NamedTensor> labels_list) { std::string filename, List<GraphData> graph_data,
std::vector<NamedTensor> labels_list) {
auto fs = std::unique_ptr<SeekStream>(dynamic_cast<SeekStream *>( auto fs = std::unique_ptr<SeekStream>(dynamic_cast<SeekStream *>(
SeekStream::Create(filename.c_str(), "w", true))); SeekStream::Create(filename.c_str(), "w", true)));
CHECK(fs) << "File name " << filename << " is not a valid local file name"; CHECK(fs) << "File name " << filename << " is not a valid local file name";
// Write DGL MetaData // Write DGL MetaData
...@@ -110,10 +111,11 @@ bool SaveDGLGraphs(std::string filename, List<GraphData> graph_data, ...@@ -110,10 +111,11 @@ bool SaveDGLGraphs(std::string filename, List<GraphData> graph_data,
return true; return true;
} }
StorageMetaData LoadDGLGraphs(const std::string &filename, StorageMetaData LoadDGLGraphs(
std::vector<dgl_id_t> idx_list, bool onlyMeta) { const std::string &filename, std::vector<dgl_id_t> idx_list,
bool onlyMeta) {
auto fs = std::unique_ptr<SeekStream>( auto fs = std::unique_ptr<SeekStream>(
SeekStream::CreateForRead(filename.c_str(), true)); SeekStream::CreateForRead(filename.c_str(), true));
CHECK(fs) << "Filename is invalid"; CHECK(fs) << "Filename is invalid";
// Read DGL MetaData // Read DGL MetaData
uint64_t magicNum, graphType, version; uint64_t magicNum, graphType, version;
...@@ -153,7 +155,7 @@ StorageMetaData LoadDGLGraphs(const std::string &filename, ...@@ -153,7 +155,7 @@ StorageMetaData LoadDGLGraphs(const std::string &filename,
for (uint64_t i = 0; i < num_graph; ++i) { for (uint64_t i = 0; i < num_graph; ++i) {
GraphData gdata = GraphData::Create(); GraphData gdata = GraphData::Create();
GraphDataObject *gdata_ptr = GraphDataObject *gdata_ptr =
const_cast<GraphDataObject *>(gdata.as<GraphDataObject>()); const_cast<GraphDataObject *>(gdata.as<GraphDataObject>());
fs->Read(gdata_ptr); fs->Read(gdata_ptr);
gdata_refs.push_back(gdata); gdata_refs.push_back(gdata);
} }
...@@ -165,12 +167,12 @@ StorageMetaData LoadDGLGraphs(const std::string &filename, ...@@ -165,12 +167,12 @@ StorageMetaData LoadDGLGraphs(const std::string &filename,
for (uint64_t i = 0; i < idx_list.size(); ++i) { for (uint64_t i = 0; i < idx_list.size(); ++i) {
auto gid = idx_list[i]; auto gid = idx_list[i];
CHECK((gid < graph_indices.size()) && (gid >= 0)) CHECK((gid < graph_indices.size()) && (gid >= 0))
<< "ID " << gid << "ID " << gid
<< " in idx_list is out of bound. Please check your idx_list."; << " in idx_list is out of bound. Please check your idx_list.";
fs->Seek(graph_indices[gid]); fs->Seek(graph_indices[gid]);
GraphData gdata = GraphData::Create(); GraphData gdata = GraphData::Create();
GraphDataObject *gdata_ptr = GraphDataObject *gdata_ptr =
const_cast<GraphDataObject *>(gdata.as<GraphDataObject>()); const_cast<GraphDataObject *>(gdata.as<GraphDataObject>());
fs->Read(gdata_ptr); fs->Read(gdata_ptr);
gdata_refs.push_back(gdata); gdata_refs.push_back(gdata);
} }
...@@ -179,9 +181,9 @@ StorageMetaData LoadDGLGraphs(const std::string &filename, ...@@ -179,9 +181,9 @@ StorageMetaData LoadDGLGraphs(const std::string &filename,
return metadata; return metadata;
} }
void GraphDataObject::SetData(ImmutableGraphPtr gptr, void GraphDataObject::SetData(
Map<std::string, Value> node_tensors, ImmutableGraphPtr gptr, Map<std::string, Value> node_tensors,
Map<std::string, Value> edge_tensors) { Map<std::string, Value> edge_tensors) {
this->gptr = gptr; this->gptr = gptr;
for (auto kv : node_tensors) { for (auto kv : node_tensors) {
...@@ -227,7 +229,7 @@ ImmutableGraphPtr BatchLoadedGraphs(std::vector<GraphData> gdata_list) { ...@@ -227,7 +229,7 @@ ImmutableGraphPtr BatchLoadedGraphs(std::vector<GraphData> gdata_list) {
gptrs.push_back(static_cast<GraphPtr>(gdata->gptr)); gptrs.push_back(static_cast<GraphPtr>(gdata->gptr));
} }
ImmutableGraphPtr imGPtr = ImmutableGraphPtr imGPtr =
std::dynamic_pointer_cast<ImmutableGraph>(GraphOp::DisjointUnion(gptrs)); std::dynamic_pointer_cast<ImmutableGraph>(GraphOp::DisjointUnion(gptrs));
return imGPtr; return imGPtr;
} }
...@@ -243,21 +245,18 @@ ImmutableGraphPtr ToImmutableGraph(GraphPtr g) { ...@@ -243,21 +245,18 @@ ImmutableGraphPtr ToImmutableGraph(GraphPtr g) {
IdArray dsts_array = earray.dst; IdArray dsts_array = earray.dst;
bool row_sorted, col_sorted; bool row_sorted, col_sorted;
std::tie(row_sorted, col_sorted) = COOIsSorted( std::tie(row_sorted, col_sorted) = COOIsSorted(aten::COOMatrix(
aten::COOMatrix(mgr->NumVertices(), mgr->NumVertices(), srcs_array, mgr->NumVertices(), mgr->NumVertices(), srcs_array, dsts_array));
dsts_array));
ImmutableGraphPtr imgptr = ImmutableGraphPtr imgptr = ImmutableGraph::CreateFromCOO(
ImmutableGraph::CreateFromCOO(mgr->NumVertices(), srcs_array, dsts_array, mgr->NumVertices(), srcs_array, dsts_array, row_sorted, col_sorted);
row_sorted, col_sorted);
return imgptr; return imgptr;
} }
} }
void StorageMetaDataObject::SetMetaData(dgl_id_t num_graph, void StorageMetaDataObject::SetMetaData(
std::vector<int64_t> nodes_num_list, dgl_id_t num_graph, std::vector<int64_t> nodes_num_list,
std::vector<int64_t> edges_num_list, std::vector<int64_t> edges_num_list, std::vector<NamedTensor> labels_list) {
std::vector<NamedTensor> labels_list) {
this->num_graph = num_graph; this->num_graph = num_graph;
this->nodes_num_list = Value(MakeValue(aten::VecToIdArray(nodes_num_list))); this->nodes_num_list = Value(MakeValue(aten::VecToIdArray(nodes_num_list)));
this->edges_num_list = Value(MakeValue(aten::VecToIdArray(edges_num_list))); this->edges_num_list = Value(MakeValue(aten::VecToIdArray(edges_num_list)));
......
...@@ -67,60 +67,60 @@ namespace dgl { ...@@ -67,60 +67,60 @@ namespace dgl {
namespace serialize { namespace serialize {
DGL_REGISTER_GLOBAL("data.graph_serialize._CAPI_MakeGraphData") DGL_REGISTER_GLOBAL("data.graph_serialize._CAPI_MakeGraphData")
.set_body([](DGLArgs args, DGLRetValue *rv) { .set_body([](DGLArgs args, DGLRetValue *rv) {
GraphRef gptr = args[0]; GraphRef gptr = args[0];
ImmutableGraphPtr imGPtr = ToImmutableGraph(gptr.sptr()); ImmutableGraphPtr imGPtr = ToImmutableGraph(gptr.sptr());
Map<std::string, Value> node_tensors = args[1]; Map<std::string, Value> node_tensors = args[1];
Map<std::string, Value> edge_tensors = args[2]; Map<std::string, Value> edge_tensors = args[2];
GraphData gd = GraphData::Create(); GraphData gd = GraphData::Create();
gd->SetData(imGPtr, node_tensors, edge_tensors); gd->SetData(imGPtr, node_tensors, edge_tensors);
*rv = gd; *rv = gd;
}); });
DGL_REGISTER_GLOBAL("data.graph_serialize._CAPI_SaveDGLGraphs_V0") DGL_REGISTER_GLOBAL("data.graph_serialize._CAPI_SaveDGLGraphs_V0")
.set_body([](DGLArgs args, DGLRetValue *rv) { .set_body([](DGLArgs args, DGLRetValue *rv) {
std::string filename = args[0]; std::string filename = args[0];
List<GraphData> graph_data = args[1]; List<GraphData> graph_data = args[1];
Map<std::string, Value> labels = args[2]; Map<std::string, Value> labels = args[2];
std::vector<NamedTensor> labels_list; std::vector<NamedTensor> labels_list;
for (auto kv : labels) { for (auto kv : labels) {
std::string name = kv.first; std::string name = kv.first;
Value v = kv.second; Value v = kv.second;
NDArray ndarray = static_cast<NDArray>(v->data); NDArray ndarray = static_cast<NDArray>(v->data);
labels_list.emplace_back(name, ndarray); labels_list.emplace_back(name, ndarray);
} }
SaveDGLGraphs(filename, graph_data, labels_list); SaveDGLGraphs(filename, graph_data, labels_list);
}); });
DGL_REGISTER_GLOBAL("data.graph_serialize._CAPI_GDataGraphHandle") DGL_REGISTER_GLOBAL("data.graph_serialize._CAPI_GDataGraphHandle")
.set_body([](DGLArgs args, DGLRetValue *rv) { .set_body([](DGLArgs args, DGLRetValue *rv) {
GraphData gdata = args[0]; GraphData gdata = args[0];
*rv = gdata->gptr; *rv = gdata->gptr;
}); });
DGL_REGISTER_GLOBAL("data.graph_serialize._CAPI_GDataNodeTensors") DGL_REGISTER_GLOBAL("data.graph_serialize._CAPI_GDataNodeTensors")
.set_body([](DGLArgs args, DGLRetValue *rv) { .set_body([](DGLArgs args, DGLRetValue *rv) {
GraphData gdata = args[0]; GraphData gdata = args[0];
Map<std::string, Value> rvmap; Map<std::string, Value> rvmap;
for (auto kv : gdata->node_tensors) { for (auto kv : gdata->node_tensors) {
rvmap.Set(kv.first, Value(MakeValue(kv.second))); rvmap.Set(kv.first, Value(MakeValue(kv.second)));
} }
*rv = rvmap; *rv = rvmap;
}); });
DGL_REGISTER_GLOBAL("data.graph_serialize._CAPI_GDataEdgeTensors") DGL_REGISTER_GLOBAL("data.graph_serialize._CAPI_GDataEdgeTensors")
.set_body([](DGLArgs args, DGLRetValue *rv) { .set_body([](DGLArgs args, DGLRetValue *rv) {
GraphData gdata = args[0]; GraphData gdata = args[0];
Map<std::string, Value> rvmap; Map<std::string, Value> rvmap;
for (auto kv : gdata->edge_tensors) { for (auto kv : gdata->edge_tensors) {
rvmap.Set(kv.first, Value(MakeValue(kv.second))); rvmap.Set(kv.first, Value(MakeValue(kv.second)));
} }
*rv = rvmap; *rv = rvmap;
}); });
uint64_t GetFileVersion(const std::string &filename) { uint64_t GetFileVersion(const std::string &filename) {
auto fs = std::unique_ptr<SeekStream>( auto fs = std::unique_ptr<SeekStream>(
SeekStream::CreateForRead(filename.c_str(), false)); SeekStream::CreateForRead(filename.c_str(), false));
CHECK(fs) << "File " << filename << " not found"; CHECK(fs) << "File " << filename << " not found";
uint64_t magicNum, version; uint64_t magicNum, version;
fs->Read(&magicNum); fs->Read(&magicNum);
...@@ -130,35 +130,36 @@ uint64_t GetFileVersion(const std::string &filename) { ...@@ -130,35 +130,36 @@ uint64_t GetFileVersion(const std::string &filename) {
} }
DGL_REGISTER_GLOBAL("data.graph_serialize._CAPI_GetFileVersion") DGL_REGISTER_GLOBAL("data.graph_serialize._CAPI_GetFileVersion")
.set_body([](DGLArgs args, DGLRetValue *rv) { .set_body([](DGLArgs args, DGLRetValue *rv) {
std::string filename = args[0]; std::string filename = args[0];
*rv = static_cast<int64_t>(GetFileVersion(filename)); *rv = static_cast<int64_t>(GetFileVersion(filename));
}); });
DGL_REGISTER_GLOBAL("data.graph_serialize._CAPI_LoadGraphFiles_V1") DGL_REGISTER_GLOBAL("data.graph_serialize._CAPI_LoadGraphFiles_V1")
.set_body([](DGLArgs args, DGLRetValue *rv) { .set_body([](DGLArgs args, DGLRetValue *rv) {
std::string filename = args[0]; std::string filename = args[0];
List<Value> idxs = args[1]; List<Value> idxs = args[1];
bool onlyMeta = args[2]; bool onlyMeta = args[2];
auto idx_list = ListValueToVector<dgl_id_t>(idxs); auto idx_list = ListValueToVector<dgl_id_t>(idxs);
*rv = LoadDGLGraphs(filename, idx_list, onlyMeta); *rv = LoadDGLGraphs(filename, idx_list, onlyMeta);
}); });
DGL_REGISTER_GLOBAL("data.graph_serialize._CAPI_DGLAsHeteroGraph") DGL_REGISTER_GLOBAL("data.graph_serialize._CAPI_DGLAsHeteroGraph")
.set_body([](DGLArgs args, DGLRetValue *rv) { .set_body([](DGLArgs args, DGLRetValue *rv) {
GraphRef g = args[0]; GraphRef g = args[0];
ImmutableGraphPtr ig = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr()); ImmutableGraphPtr ig =
CHECK(ig) << "graph is not readonly"; std::dynamic_pointer_cast<ImmutableGraph>(g.sptr());
*rv = HeteroGraphRef(ig->AsHeteroGraph()); CHECK(ig) << "graph is not readonly";
}); *rv = HeteroGraphRef(ig->AsHeteroGraph());
});
DGL_REGISTER_GLOBAL("data.graph_serialize._CAPI_LoadGraphFiles_V2") DGL_REGISTER_GLOBAL("data.graph_serialize._CAPI_LoadGraphFiles_V2")
.set_body([](DGLArgs args, DGLRetValue *rv) { .set_body([](DGLArgs args, DGLRetValue *rv) {
std::string filename = args[0]; std::string filename = args[0];
List<Value> idxs = args[1]; List<Value> idxs = args[1];
auto idx_list = ListValueToVector<dgl_id_t>(idxs); auto idx_list = ListValueToVector<dgl_id_t>(idxs);
*rv = List<HeteroGraphData>(LoadHeteroGraphs(filename, idx_list)); *rv = List<HeteroGraphData>(LoadHeteroGraphs(filename, idx_list));
}); });
} // namespace serialize } // namespace serialize
} // namespace dgl } // namespace dgl
...@@ -48,8 +48,8 @@ ...@@ -48,8 +48,8 @@
#include <vector> #include <vector>
#include "../heterograph.h" #include "../heterograph.h"
#include "./graph_serialize.h"
#include "./dglstream.h" #include "./dglstream.h"
#include "./graph_serialize.h"
#include "dmlc/memory_io.h" #include "dmlc/memory_io.h"
namespace dgl { namespace dgl {
...@@ -61,11 +61,11 @@ using dmlc::Stream; ...@@ -61,11 +61,11 @@ using dmlc::Stream;
using dmlc::io::FileSystem; using dmlc::io::FileSystem;
using dmlc::io::URI; using dmlc::io::URI;
bool SaveHeteroGraphs(std::string filename, List<HeteroGraphData> hdata, bool SaveHeteroGraphs(
const std::vector<NamedTensor> &nd_list, std::string filename, List<HeteroGraphData> hdata,
dgl_format_code_t formats) { const std::vector<NamedTensor> &nd_list, dgl_format_code_t formats) {
auto fs = std::unique_ptr<DGLStream>( auto fs = std::unique_ptr<DGLStream>(
DGLStream::Create(filename.c_str(), "w", false, formats)); DGLStream::Create(filename.c_str(), "w", false, formats));
CHECK(fs->IsValid()) << "File name " << filename << " is not a valid name"; CHECK(fs->IsValid()) << "File name " << filename << " is not a valid name";
// Write DGL MetaData // Write DGL MetaData
...@@ -91,7 +91,7 @@ bool SaveHeteroGraphs(std::string filename, List<HeteroGraphData> hdata, ...@@ -91,7 +91,7 @@ bool SaveHeteroGraphs(std::string filename, List<HeteroGraphData> hdata,
label_fs->Write(nd_list); label_fs->Write(nd_list);
uint64_t gdata_start_pos = uint64_t gdata_start_pos =
fs->Count() + sizeof(uint64_t) + labels_blob.size(); fs->Count() + sizeof(uint64_t) + labels_blob.size();
// Write start position of gdata, which can be skipped when only reading gdata // Write start position of gdata, which can be skipped when only reading gdata
// And label dict // And label dict
...@@ -120,10 +120,10 @@ bool SaveHeteroGraphs(std::string filename, List<HeteroGraphData> hdata, ...@@ -120,10 +120,10 @@ bool SaveHeteroGraphs(std::string filename, List<HeteroGraphData> hdata,
return true; return true;
} }
std::vector<HeteroGraphData> LoadHeteroGraphs(const std::string &filename, std::vector<HeteroGraphData> LoadHeteroGraphs(
std::vector<dgl_id_t> idx_list) { const std::string &filename, std::vector<dgl_id_t> idx_list) {
auto fs = std::unique_ptr<SeekStream>( auto fs = std::unique_ptr<SeekStream>(
SeekStream::CreateForRead(filename.c_str(), false)); SeekStream::CreateForRead(filename.c_str(), false));
CHECK(fs) << "File name " << filename << " is not a valid name"; CHECK(fs) << "File name " << filename << " is not a valid name";
// Read DGL MetaData // Read DGL MetaData
uint64_t magicNum, graphType, version, num_graph; uint64_t magicNum, graphType, version, num_graph;
...@@ -172,8 +172,8 @@ std::vector<HeteroGraphData> LoadHeteroGraphs(const std::string &filename, ...@@ -172,8 +172,8 @@ std::vector<HeteroGraphData> LoadHeteroGraphs(const std::string &filename,
for (uint64_t i = 0; i < idx_list.size(); ++i) { for (uint64_t i = 0; i < idx_list.size(); ++i) {
auto gid = idx_list[i]; auto gid = idx_list[i];
CHECK((gid < graph_indices.size()) && (gid >= 0)) CHECK((gid < graph_indices.size()) && (gid >= 0))
<< "ID " << gid << "ID " << gid
<< " in idx_list is out of bound. Please check your idx_list."; << " in idx_list is out of bound. Please check your idx_list.";
fs->Seek(graph_indices[gid]); fs->Seek(graph_indices[gid]);
HeteroGraphData gdata = HeteroGraphData::Create(); HeteroGraphData gdata = HeteroGraphData::Create();
auto hetero_data = gdata.sptr(); auto hetero_data = gdata.sptr();
...@@ -187,7 +187,7 @@ std::vector<HeteroGraphData> LoadHeteroGraphs(const std::string &filename, ...@@ -187,7 +187,7 @@ std::vector<HeteroGraphData> LoadHeteroGraphs(const std::string &filename,
std::vector<NamedTensor> LoadLabels_V2(const std::string &filename) { std::vector<NamedTensor> LoadLabels_V2(const std::string &filename) {
auto fs = std::unique_ptr<SeekStream>( auto fs = std::unique_ptr<SeekStream>(
SeekStream::CreateForRead(filename.c_str(), false)); SeekStream::CreateForRead(filename.c_str(), false));
CHECK(fs) << "File name " << filename << " is not a valid name"; CHECK(fs) << "File name " << filename << " is not a valid name";
// Read DGL MetaData // Read DGL MetaData
uint64_t magicNum, graphType, version, num_graph; uint64_t magicNum, graphType, version, num_graph;
...@@ -207,106 +207,107 @@ std::vector<NamedTensor> LoadLabels_V2(const std::string &filename) { ...@@ -207,106 +207,107 @@ std::vector<NamedTensor> LoadLabels_V2(const std::string &filename) {
} }
DGL_REGISTER_GLOBAL("data.heterograph_serialize._CAPI_MakeHeteroGraphData") DGL_REGISTER_GLOBAL("data.heterograph_serialize._CAPI_MakeHeteroGraphData")
.set_body([](DGLArgs args, DGLRetValue *rv) { .set_body([](DGLArgs args, DGLRetValue *rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
List<Map<std::string, Value>> ndata = args[1]; List<Map<std::string, Value>> ndata = args[1];
List<Map<std::string, Value>> edata = args[2]; List<Map<std::string, Value>> edata = args[2];
List<Value> ntype_names = args[3]; List<Value> ntype_names = args[3];
List<Value> etype_names = args[4]; List<Value> etype_names = args[4];
*rv = HeteroGraphData::Create(hg.sptr(), ndata, edata, ntype_names, *rv = HeteroGraphData::Create(
etype_names); hg.sptr(), ndata, edata, ntype_names, etype_names);
}); });
DGL_REGISTER_GLOBAL("data.heterograph_serialize._CAPI_SaveHeteroGraphData") DGL_REGISTER_GLOBAL("data.heterograph_serialize._CAPI_SaveHeteroGraphData")
.set_body([](DGLArgs args, DGLRetValue *rv) { .set_body([](DGLArgs args, DGLRetValue *rv) {
std::string filename = args[0]; std::string filename = args[0];
List<HeteroGraphData> hgdata = args[1]; List<HeteroGraphData> hgdata = args[1];
Map<std::string, Value> nd_map = args[2]; Map<std::string, Value> nd_map = args[2];
List<Value> formats = args[3]; List<Value> formats = args[3];
std::vector<SparseFormat> formats_vec; std::vector<SparseFormat> formats_vec;
for (const auto& val : formats) { for (const auto &val : formats) {
formats_vec.push_back(ParseSparseFormat(val->data)); formats_vec.push_back(ParseSparseFormat(val->data));
} }
const auto formats_code = SparseFormatsToCode(formats_vec); const auto formats_code = SparseFormatsToCode(formats_vec);
std::vector<NamedTensor> nd_list; std::vector<NamedTensor> nd_list;
for (auto kv : nd_map) { for (auto kv : nd_map) {
NDArray ndarray = static_cast<NDArray>(kv.second->data); NDArray ndarray = static_cast<NDArray>(kv.second->data);
nd_list.emplace_back(kv.first, ndarray); nd_list.emplace_back(kv.first, ndarray);
} }
*rv = dgl::serialize::SaveHeteroGraphs(filename, hgdata, nd_list, formats_code); *rv = dgl::serialize::SaveHeteroGraphs(
}); filename, hgdata, nd_list, formats_code);
});
DGL_REGISTER_GLOBAL( DGL_REGISTER_GLOBAL(
"data.heterograph_serialize._CAPI_GetGindexFromHeteroGraphData") "data.heterograph_serialize._CAPI_GetGindexFromHeteroGraphData")
.set_body([](DGLArgs args, DGLRetValue *rv) { .set_body([](DGLArgs args, DGLRetValue *rv) {
HeteroGraphData hdata = args[0]; HeteroGraphData hdata = args[0];
*rv = HeteroGraphRef(hdata->gptr); *rv = HeteroGraphRef(hdata->gptr);
}); });
DGL_REGISTER_GLOBAL( DGL_REGISTER_GLOBAL(
"data.heterograph_serialize._CAPI_GetEtypesFromHeteroGraphData") "data.heterograph_serialize._CAPI_GetEtypesFromHeteroGraphData")
.set_body([](DGLArgs args, DGLRetValue *rv) { .set_body([](DGLArgs args, DGLRetValue *rv) {
HeteroGraphData hdata = args[0]; HeteroGraphData hdata = args[0];
List<Value> etype_names; List<Value> etype_names;
for (const auto &name : hdata->etype_names) { for (const auto &name : hdata->etype_names) {
etype_names.push_back(Value(MakeValue(name))); etype_names.push_back(Value(MakeValue(name)));
} }
*rv = etype_names; *rv = etype_names;
}); });
DGL_REGISTER_GLOBAL( DGL_REGISTER_GLOBAL(
"data.heterograph_serialize._CAPI_GetNtypesFromHeteroGraphData") "data.heterograph_serialize._CAPI_GetNtypesFromHeteroGraphData")
.set_body([](DGLArgs args, DGLRetValue *rv) { .set_body([](DGLArgs args, DGLRetValue *rv) {
HeteroGraphData hdata = args[0]; HeteroGraphData hdata = args[0];
List<Value> ntype_names; List<Value> ntype_names;
for (auto name : hdata->ntype_names) { for (auto name : hdata->ntype_names) {
ntype_names.push_back(Value(MakeValue(name))); ntype_names.push_back(Value(MakeValue(name)));
} }
*rv = ntype_names; *rv = ntype_names;
}); });
DGL_REGISTER_GLOBAL( DGL_REGISTER_GLOBAL(
"data.heterograph_serialize._CAPI_GetNDataFromHeteroGraphData") "data.heterograph_serialize._CAPI_GetNDataFromHeteroGraphData")
.set_body([](DGLArgs args, DGLRetValue *rv) { .set_body([](DGLArgs args, DGLRetValue *rv) {
HeteroGraphData hdata = args[0]; HeteroGraphData hdata = args[0];
List<List<Value>> ntensors; List<List<Value>> ntensors;
for (auto tensor_list : hdata->node_tensors) { for (auto tensor_list : hdata->node_tensors) {
List<Value> nlist; List<Value> nlist;
for (const auto &kv : tensor_list) { for (const auto &kv : tensor_list) {
nlist.push_back(Value(MakeValue(kv.first))); nlist.push_back(Value(MakeValue(kv.first)));
nlist.push_back(Value(MakeValue(kv.second))); nlist.push_back(Value(MakeValue(kv.second)));
}
ntensors.push_back(nlist);
} }
ntensors.push_back(nlist); *rv = ntensors;
} });
*rv = ntensors;
});
DGL_REGISTER_GLOBAL( DGL_REGISTER_GLOBAL(
"data.heterograph_serialize._CAPI_GetEDataFromHeteroGraphData") "data.heterograph_serialize._CAPI_GetEDataFromHeteroGraphData")
.set_body([](DGLArgs args, DGLRetValue *rv) { .set_body([](DGLArgs args, DGLRetValue *rv) {
HeteroGraphData hdata = args[0]; HeteroGraphData hdata = args[0];
List<List<Value>> etensors; List<List<Value>> etensors;
for (auto tensor_list : hdata->edge_tensors) { for (auto tensor_list : hdata->edge_tensors) {
List<Value> elist; List<Value> elist;
for (const auto &kv : tensor_list) { for (const auto &kv : tensor_list) {
elist.push_back(Value(MakeValue(kv.first))); elist.push_back(Value(MakeValue(kv.first)));
elist.push_back(Value(MakeValue(kv.second))); elist.push_back(Value(MakeValue(kv.second)));
}
etensors.push_back(elist);
} }
etensors.push_back(elist); *rv = etensors;
} });
*rv = etensors;
});
DGL_REGISTER_GLOBAL("data.graph_serialize._CAPI_LoadLabels_V2") DGL_REGISTER_GLOBAL("data.graph_serialize._CAPI_LoadLabels_V2")
.set_body([](DGLArgs args, DGLRetValue *rv) { .set_body([](DGLArgs args, DGLRetValue *rv) {
std::string filename = args[0]; std::string filename = args[0];
auto labels_list = LoadLabels_V2(filename); auto labels_list = LoadLabels_V2(filename);
Map<std::string, Value> rvmap; Map<std::string, Value> rvmap;
for (auto kv : labels_list) { for (auto kv : labels_list) {
rvmap.Set(kv.first, Value(MakeValue(kv.second))); rvmap.Set(kv.first, Value(MakeValue(kv.second)));
} }
*rv = rvmap; *rv = rvmap;
}); });
} // namespace serialize } // namespace serialize
} // namespace dgl } // namespace dgl
...@@ -3,13 +3,14 @@ ...@@ -3,13 +3,14 @@
* @file graph/unit_graph.cc * @file graph/unit_graph.cc
* @brief UnitGraph graph implementation * @brief UnitGraph graph implementation
*/ */
#include "./unit_graph.h"
#include <dgl/array.h> #include <dgl/array.h>
#include <dgl/base_heterograph.h> #include <dgl/base_heterograph.h>
#include <dgl/immutable_graph.h> #include <dgl/immutable_graph.h>
#include <dgl/lazy.h> #include <dgl/lazy.h>
#include "../c_api_common.h" #include "../c_api_common.h"
#include "./unit_graph.h"
#include "./serialize/dglstream.h" #include "./serialize/dglstream.h"
namespace dgl { namespace dgl {
...@@ -64,17 +65,17 @@ class UnitGraph::COO : public BaseHeteroGraph { ...@@ -64,17 +65,17 @@ class UnitGraph::COO : public BaseHeteroGraph {
public: public:
COO(GraphPtr metagraph, int64_t num_src, int64_t num_dst, IdArray src, COO(GraphPtr metagraph, int64_t num_src, int64_t num_dst, IdArray src,
IdArray dst, bool row_sorted = false, bool col_sorted = false) IdArray dst, bool row_sorted = false, bool col_sorted = false)
: BaseHeteroGraph(metagraph) { : BaseHeteroGraph(metagraph) {
CHECK(aten::IsValidIdArray(src)); CHECK(aten::IsValidIdArray(src));
CHECK(aten::IsValidIdArray(dst)); CHECK(aten::IsValidIdArray(dst));
CHECK_EQ(src->shape[0], dst->shape[0]) << "Input arrays should have the same length."; CHECK_EQ(src->shape[0], dst->shape[0])
adj_ = aten::COOMatrix{num_src, num_dst, src, dst, << "Input arrays should have the same length.";
NullArray(), adj_ = aten::COOMatrix{num_src, num_dst, src, dst,
row_sorted, col_sorted}; NullArray(), row_sorted, col_sorted};
} }
COO(GraphPtr metagraph, const aten::COOMatrix& coo) COO(GraphPtr metagraph, const aten::COOMatrix& coo)
: BaseHeteroGraph(metagraph), adj_(coo) { : BaseHeteroGraph(metagraph), adj_(coo) {
// Data index should not be inherited. Edges in COO format are always // Data index should not be inherited. Edges in COO format are always
// assigned ids from 0 to num_edges - 1. // assigned ids from 0 to num_edges - 1.
CHECK(!COOHasData(coo)) << "[BUG] COO should not contain data."; CHECK(!COOHasData(coo)) << "[BUG] COO should not contain data.";
...@@ -83,30 +84,23 @@ class UnitGraph::COO : public BaseHeteroGraph { ...@@ -83,30 +84,23 @@ class UnitGraph::COO : public BaseHeteroGraph {
COO() { COO() {
// set magic num_rows/num_cols to mark it as undefined // set magic num_rows/num_cols to mark it as undefined
// adj_.num_rows == 0 and adj_.num_cols == 0 means empty UnitGraph which is supported // adj_.num_rows == 0 and adj_.num_cols == 0 means empty UnitGraph which is
// supported
adj_.num_rows = -1; adj_.num_rows = -1;
adj_.num_cols = -1; adj_.num_cols = -1;
}; };
bool defined() const { bool defined() const { return (adj_.num_rows >= 0) && (adj_.num_cols >= 0); }
return (adj_.num_rows >= 0) && (adj_.num_cols >= 0);
}
inline dgl_type_t SrcType() const { inline dgl_type_t SrcType() const { return 0; }
return 0;
}
inline dgl_type_t DstType() const { inline dgl_type_t DstType() const { return NumVertexTypes() == 1 ? 0 : 1; }
return NumVertexTypes() == 1? 0 : 1;
}
inline dgl_type_t EdgeType() const { inline dgl_type_t EdgeType() const { return 0; }
return 0;
}
HeteroGraphPtr GetRelationGraph(dgl_type_t etype) const override { HeteroGraphPtr GetRelationGraph(dgl_type_t etype) const override {
LOG(FATAL) << "The method shouldn't be called for UnitGraph graph. " LOG(FATAL) << "The method shouldn't be called for UnitGraph graph. "
<< "The relation graph is simply this graph itself."; << "The relation graph is simply this graph itself.";
return {}; return {};
} }
...@@ -122,67 +116,44 @@ class UnitGraph::COO : public BaseHeteroGraph { ...@@ -122,67 +116,44 @@ class UnitGraph::COO : public BaseHeteroGraph {
LOG(FATAL) << "UnitGraph graph is not mutable."; LOG(FATAL) << "UnitGraph graph is not mutable.";
} }
void Clear() override { void Clear() override { LOG(FATAL) << "UnitGraph graph is not mutable."; }
LOG(FATAL) << "UnitGraph graph is not mutable.";
}
DGLDataType DataType() const override { DGLDataType DataType() const override { return adj_.row->dtype; }
return adj_.row->dtype;
}
DGLContext Context() const override { DGLContext Context() const override { return adj_.row->ctx; }
return adj_.row->ctx;
}
bool IsPinned() const override { bool IsPinned() const override { return adj_.is_pinned; }
return adj_.is_pinned;
}
uint8_t NumBits() const override { uint8_t NumBits() const override { return adj_.row->dtype.bits; }
return adj_.row->dtype.bits;
}
COO AsNumBits(uint8_t bits) const { COO AsNumBits(uint8_t bits) const {
if (NumBits() == bits) if (NumBits() == bits) return *this;
return *this;
COO ret( COO ret(
meta_graph_, meta_graph_, adj_.num_rows, adj_.num_cols,
adj_.num_rows, adj_.num_cols, aten::AsNumBits(adj_.row, bits), aten::AsNumBits(adj_.col, bits));
aten::AsNumBits(adj_.row, bits),
aten::AsNumBits(adj_.col, bits));
return ret; return ret;
} }
COO CopyTo(const DGLContext &ctx) const { COO CopyTo(const DGLContext& ctx) const {
if (Context() == ctx) if (Context() == ctx) return *this;
return *this;
return COO(meta_graph_, adj_.CopyTo(ctx)); return COO(meta_graph_, adj_.CopyTo(ctx));
} }
/** @brief Pin the adj_: COOMatrix of the COO graph. */ /** @brief Pin the adj_: COOMatrix of the COO graph. */
void PinMemory_() { void PinMemory_() { adj_.PinMemory_(); }
adj_.PinMemory_();
}
/** @brief Unpin the adj_: COOMatrix of the COO graph. */ /** @brief Unpin the adj_: COOMatrix of the COO graph. */
void UnpinMemory_() { void UnpinMemory_() { adj_.UnpinMemory_(); }
adj_.UnpinMemory_();
}
/** @brief Record stream for the adj_: COOMatrix of the COO graph. */ /** @brief Record stream for the adj_: COOMatrix of the COO graph. */
void RecordStream(DGLStreamHandle stream) override { void RecordStream(DGLStreamHandle stream) override {
adj_.RecordStream(stream); adj_.RecordStream(stream);
} }
bool IsMultigraph() const override { bool IsMultigraph() const override { return aten::COOHasDuplicate(adj_); }
return aten::COOHasDuplicate(adj_);
}
bool IsReadonly() const override { bool IsReadonly() const override { return true; }
return true;
}
uint64_t NumVertices(dgl_type_t vtype) const override { uint64_t NumVertices(dgl_type_t vtype) const override {
if (vtype == SrcType()) { if (vtype == SrcType()) {
...@@ -208,13 +179,15 @@ class UnitGraph::COO : public BaseHeteroGraph { ...@@ -208,13 +179,15 @@ class UnitGraph::COO : public BaseHeteroGraph {
return {}; return {};
} }
bool HasEdgeBetween(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const override { bool HasEdgeBetween(
dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const override {
CHECK(HasVertex(SrcType(), src)) << "Invalid src vertex id: " << src; CHECK(HasVertex(SrcType(), src)) << "Invalid src vertex id: " << src;
CHECK(HasVertex(DstType(), dst)) << "Invalid dst vertex id: " << dst; CHECK(HasVertex(DstType(), dst)) << "Invalid dst vertex id: " << dst;
return aten::COOIsNonZero(adj_, src, dst); return aten::COOIsNonZero(adj_, src, dst);
} }
BoolArray HasEdgesBetween(dgl_type_t etype, IdArray src_ids, IdArray dst_ids) const override { BoolArray HasEdgesBetween(
dgl_type_t etype, IdArray src_ids, IdArray dst_ids) const override {
CHECK(aten::IsValidIdArray(src_ids)) << "Invalid vertex id array."; CHECK(aten::IsValidIdArray(src_ids)) << "Invalid vertex id array.";
CHECK(aten::IsValidIdArray(dst_ids)) << "Invalid vertex id array."; CHECK(aten::IsValidIdArray(dst_ids)) << "Invalid vertex id array.";
return aten::COOIsNonZero(adj_, src_ids, dst_ids); return aten::COOIsNonZero(adj_, src_ids, dst_ids);
...@@ -236,18 +209,21 @@ class UnitGraph::COO : public BaseHeteroGraph { ...@@ -236,18 +209,21 @@ class UnitGraph::COO : public BaseHeteroGraph {
return aten::COOGetAllData(adj_, src, dst); return aten::COOGetAllData(adj_, src, dst);
} }
EdgeArray EdgeIdsAll(dgl_type_t etype, IdArray src, IdArray dst) const override { EdgeArray EdgeIdsAll(
dgl_type_t etype, IdArray src, IdArray dst) const override {
CHECK(aten::IsValidIdArray(src)) << "Invalid vertex id array."; CHECK(aten::IsValidIdArray(src)) << "Invalid vertex id array.";
CHECK(aten::IsValidIdArray(dst)) << "Invalid vertex id array."; CHECK(aten::IsValidIdArray(dst)) << "Invalid vertex id array.";
const auto& arrs = aten::COOGetDataAndIndices(adj_, src, dst); const auto& arrs = aten::COOGetDataAndIndices(adj_, src, dst);
return EdgeArray{arrs[0], arrs[1], arrs[2]}; return EdgeArray{arrs[0], arrs[1], arrs[2]};
} }
IdArray EdgeIdsOne(dgl_type_t etype, IdArray src, IdArray dst) const override { IdArray EdgeIdsOne(
dgl_type_t etype, IdArray src, IdArray dst) const override {
return aten::COOGetData(adj_, src, dst); return aten::COOGetData(adj_, src, dst);
} }
std::pair<dgl_id_t, dgl_id_t> FindEdge(dgl_type_t etype, dgl_id_t eid) const override { std::pair<dgl_id_t, dgl_id_t> FindEdge(
dgl_type_t etype, dgl_id_t eid) const override {
CHECK(eid < NumEdges(etype)) << "Invalid edge id: " << eid; CHECK(eid < NumEdges(etype)) << "Invalid edge id: " << eid;
const dgl_id_t src = aten::IndexSelect<int64_t>(adj_.row, eid); const dgl_id_t src = aten::IndexSelect<int64_t>(adj_.row, eid);
const dgl_id_t dst = aten::IndexSelect<int64_t>(adj_.col, eid); const dgl_id_t dst = aten::IndexSelect<int64_t>(adj_.col, eid);
...@@ -256,18 +232,19 @@ class UnitGraph::COO : public BaseHeteroGraph { ...@@ -256,18 +232,19 @@ class UnitGraph::COO : public BaseHeteroGraph {
EdgeArray FindEdges(dgl_type_t etype, IdArray eids) const override { EdgeArray FindEdges(dgl_type_t etype, IdArray eids) const override {
CHECK(aten::IsValidIdArray(eids)) << "Invalid edge id array"; CHECK(aten::IsValidIdArray(eids)) << "Invalid edge id array";
BUG_IF_FAIL(aten::IsNullArray(adj_.data)) << BUG_IF_FAIL(aten::IsNullArray(adj_.data))
"FindEdges requires the internal COO matrix not having EIDs."; << "FindEdges requires the internal COO matrix not having EIDs.";
return EdgeArray{aten::IndexSelect(adj_.row, eids), return EdgeArray{
aten::IndexSelect(adj_.col, eids), aten::IndexSelect(adj_.row, eids), aten::IndexSelect(adj_.col, eids),
eids}; eids};
} }
EdgeArray InEdges(dgl_type_t etype, dgl_id_t vid) const override { EdgeArray InEdges(dgl_type_t etype, dgl_id_t vid) const override {
IdArray ret_src, ret_eid; IdArray ret_src, ret_eid;
std::tie(ret_eid, ret_src) = aten::COOGetRowDataAndIndices( std::tie(ret_eid, ret_src) =
aten::COOTranspose(adj_), vid); aten::COOGetRowDataAndIndices(aten::COOTranspose(adj_), vid);
IdArray ret_dst = aten::Full(vid, ret_src->shape[0], NumBits(), ret_src->ctx); IdArray ret_dst =
aten::Full(vid, ret_src->shape[0], NumBits(), ret_src->ctx);
return EdgeArray{ret_src, ret_dst, ret_eid}; return EdgeArray{ret_src, ret_dst, ret_eid};
} }
...@@ -281,7 +258,8 @@ class UnitGraph::COO : public BaseHeteroGraph { ...@@ -281,7 +258,8 @@ class UnitGraph::COO : public BaseHeteroGraph {
EdgeArray OutEdges(dgl_type_t etype, dgl_id_t vid) const override { EdgeArray OutEdges(dgl_type_t etype, dgl_id_t vid) const override {
IdArray ret_dst, ret_eid; IdArray ret_dst, ret_eid;
std::tie(ret_eid, ret_dst) = aten::COOGetRowDataAndIndices(adj_, vid); std::tie(ret_eid, ret_dst) = aten::COOGetRowDataAndIndices(adj_, vid);
IdArray ret_src = aten::Full(vid, ret_dst->shape[0], NumBits(), ret_dst->ctx); IdArray ret_src =
aten::Full(vid, ret_dst->shape[0], NumBits(), ret_dst->ctx);
return EdgeArray{ret_src, ret_dst, ret_eid}; return EdgeArray{ret_src, ret_dst, ret_eid};
} }
...@@ -292,10 +270,11 @@ class UnitGraph::COO : public BaseHeteroGraph { ...@@ -292,10 +270,11 @@ class UnitGraph::COO : public BaseHeteroGraph {
return EdgeArray{row, coosubmat.col, coosubmat.data}; return EdgeArray{row, coosubmat.col, coosubmat.data};
} }
EdgeArray Edges(dgl_type_t etype, const std::string &order = "") const override { EdgeArray Edges(
dgl_type_t etype, const std::string& order = "") const override {
CHECK(order.empty() || order == std::string("eid")) CHECK(order.empty() || order == std::string("eid"))
<< "COO only support Edges of order \"eid\", but got \"" << "COO only support Edges of order \"eid\", but got \"" << order
<< order << "\"."; << "\".";
IdArray rst_eid = aten::Range(0, NumEdges(etype), NumBits(), Context()); IdArray rst_eid = aten::Range(0, NumEdges(etype), NumBits(), Context());
return EdgeArray{adj_.row, adj_.col, rst_eid}; return EdgeArray{adj_.row, adj_.col, rst_eid};
} }
...@@ -341,7 +320,7 @@ class UnitGraph::COO : public BaseHeteroGraph { ...@@ -341,7 +320,7 @@ class UnitGraph::COO : public BaseHeteroGraph {
} }
std::vector<IdArray> GetAdj( std::vector<IdArray> GetAdj(
dgl_type_t etype, bool transpose, const std::string &fmt) const override { dgl_type_t etype, bool transpose, const std::string& fmt) const override {
CHECK(fmt == "coo") << "Not valid adj format request."; CHECK(fmt == "coo") << "Not valid adj format request.";
if (transpose) { if (transpose) {
return {aten::HStack(adj_.col, adj_.row)}; return {aten::HStack(adj_.col, adj_.row)};
...@@ -350,9 +329,7 @@ class UnitGraph::COO : public BaseHeteroGraph { ...@@ -350,9 +329,7 @@ class UnitGraph::COO : public BaseHeteroGraph {
} }
} }
aten::COOMatrix GetCOOMatrix(dgl_type_t etype) const override { aten::COOMatrix GetCOOMatrix(dgl_type_t etype) const override { return adj_; }
return adj_;
}
aten::CSRMatrix GetCSCMatrix(dgl_type_t etype) const override { aten::CSRMatrix GetCSCMatrix(dgl_type_t etype) const override {
LOG(FATAL) << "Not enabled for COO graph"; LOG(FATAL) << "Not enabled for COO graph";
...@@ -364,7 +341,8 @@ class UnitGraph::COO : public BaseHeteroGraph { ...@@ -364,7 +341,8 @@ class UnitGraph::COO : public BaseHeteroGraph {
return aten::CSRMatrix(); return aten::CSRMatrix();
} }
SparseFormat SelectFormat(dgl_type_t etype, dgl_format_code_t preferred_formats) const override { SparseFormat SelectFormat(
dgl_type_t etype, dgl_format_code_t preferred_formats) const override {
LOG(FATAL) << "Not enabled for COO graph"; LOG(FATAL) << "Not enabled for COO graph";
return SparseFormat::kCOO; return SparseFormat::kCOO;
} }
...@@ -379,8 +357,10 @@ class UnitGraph::COO : public BaseHeteroGraph { ...@@ -379,8 +357,10 @@ class UnitGraph::COO : public BaseHeteroGraph {
return 0; return 0;
} }
HeteroSubgraph VertexSubgraph(const std::vector<IdArray>& vids) const override { HeteroSubgraph VertexSubgraph(
CHECK_EQ(vids.size(), NumVertexTypes()) << "Number of vertex types mismatch"; const std::vector<IdArray>& vids) const override {
CHECK_EQ(vids.size(), NumVertexTypes())
<< "Number of vertex types mismatch";
auto srcvids = vids[SrcType()], dstvids = vids[DstType()]; auto srcvids = vids[SrcType()], dstvids = vids[DstType()];
CHECK(aten::IsValidIdArray(srcvids)) << "Invalid vertex id array."; CHECK(aten::IsValidIdArray(srcvids)) << "Invalid vertex id array.";
CHECK(aten::IsValidIdArray(dstvids)) << "Invalid vertex id array."; CHECK(aten::IsValidIdArray(dstvids)) << "Invalid vertex id array.";
...@@ -388,15 +368,16 @@ class UnitGraph::COO : public BaseHeteroGraph { ...@@ -388,15 +368,16 @@ class UnitGraph::COO : public BaseHeteroGraph {
const auto& submat = aten::COOSliceMatrix(adj_, srcvids, dstvids); const auto& submat = aten::COOSliceMatrix(adj_, srcvids, dstvids);
DGLContext ctx = aten::GetContextOf(vids); DGLContext ctx = aten::GetContextOf(vids);
IdArray sub_eids = aten::Range(0, submat.data->shape[0], NumBits(), ctx); IdArray sub_eids = aten::Range(0, submat.data->shape[0], NumBits(), ctx);
subg.graph = std::make_shared<COO>(meta_graph(), submat.num_rows, submat.num_cols, subg.graph = std::make_shared<COO>(
submat.row, submat.col); meta_graph(), submat.num_rows, submat.num_cols, submat.row, submat.col);
subg.induced_vertices = vids; subg.induced_vertices = vids;
subg.induced_edges.emplace_back(submat.data); subg.induced_edges.emplace_back(submat.data);
return subg; return subg;
} }
HeteroSubgraph EdgeSubgraph( HeteroSubgraph EdgeSubgraph(
const std::vector<IdArray>& eids, bool preserve_nodes = false) const override { const std::vector<IdArray>& eids,
bool preserve_nodes = false) const override {
CHECK_EQ(eids.size(), 1) << "Edge type number mismatch."; CHECK_EQ(eids.size(), 1) << "Edge type number mismatch.";
HeteroSubgraph subg; HeteroSubgraph subg;
if (!preserve_nodes) { if (!preserve_nodes) {
...@@ -417,7 +398,8 @@ class UnitGraph::COO : public BaseHeteroGraph { ...@@ -417,7 +398,8 @@ class UnitGraph::COO : public BaseHeteroGraph {
subg.induced_vertices.emplace_back( subg.induced_vertices.emplace_back(
aten::NullArray(DGLDataType{kDGLInt, NumBits(), 1}, Context())); aten::NullArray(DGLDataType{kDGLInt, NumBits(), 1}, Context()));
subg.graph = std::make_shared<COO>( subg.graph = std::make_shared<COO>(
meta_graph(), NumVertices(SrcType()), NumVertices(DstType()), new_src, new_dst); meta_graph(), NumVertices(SrcType()), NumVertices(DstType()), new_src,
new_dst);
subg.induced_edges = eids; subg.induced_edges = eids;
} }
return subg; return subg;
...@@ -428,13 +410,11 @@ class UnitGraph::COO : public BaseHeteroGraph { ...@@ -428,13 +410,11 @@ class UnitGraph::COO : public BaseHeteroGraph {
return nullptr; return nullptr;
} }
aten::COOMatrix adj() const { aten::COOMatrix adj() const { return adj_; }
return adj_;
}
/** /**
* @brief Determines whether the graph is "hypersparse", i.e. having significantly more * @brief Determines whether the graph is "hypersparse", i.e. having
* nodes than edges. * significantly more nodes than edges.
*/ */
bool IsHypersparse() const { bool IsHypersparse() const {
return (NumVertices(SrcType()) / 8 > NumEdges(EdgeType())) && return (NumVertices(SrcType()) / 8 > NumEdges(EdgeType())) &&
...@@ -470,50 +450,45 @@ class UnitGraph::COO : public BaseHeteroGraph { ...@@ -470,50 +450,45 @@ class UnitGraph::COO : public BaseHeteroGraph {
/** @brief CSR graph */ /** @brief CSR graph */
class UnitGraph::CSR : public BaseHeteroGraph { class UnitGraph::CSR : public BaseHeteroGraph {
public: public:
CSR(GraphPtr metagraph, int64_t num_src, int64_t num_dst, CSR(GraphPtr metagraph, int64_t num_src, int64_t num_dst, IdArray indptr,
IdArray indptr, IdArray indices, IdArray edge_ids) IdArray indices, IdArray edge_ids)
: BaseHeteroGraph(metagraph) { : BaseHeteroGraph(metagraph) {
CHECK(aten::IsValidIdArray(indptr)); CHECK(aten::IsValidIdArray(indptr));
CHECK(aten::IsValidIdArray(indices)); CHECK(aten::IsValidIdArray(indices));
if (aten::IsValidIdArray(edge_ids)) if (aten::IsValidIdArray(edge_ids))
CHECK((indices->shape[0] == edge_ids->shape[0]) || aten::IsNullArray(edge_ids)) CHECK(
<< "edge id arrays should have the same length as indices if not empty"; (indices->shape[0] == edge_ids->shape[0]) ||
aten::IsNullArray(edge_ids))
<< "edge id arrays should have the same length as indices if not "
"empty";
CHECK_EQ(num_src, indptr->shape[0] - 1) CHECK_EQ(num_src, indptr->shape[0] - 1)
<< "number of nodes do not match the length of indptr minus 1."; << "number of nodes do not match the length of indptr minus 1.";
adj_ = aten::CSRMatrix{num_src, num_dst, indptr, indices, edge_ids}; adj_ = aten::CSRMatrix{num_src, num_dst, indptr, indices, edge_ids};
} }
CSR(GraphPtr metagraph, const aten::CSRMatrix& csr) CSR(GraphPtr metagraph, const aten::CSRMatrix& csr)
: BaseHeteroGraph(metagraph), adj_(csr) { : BaseHeteroGraph(metagraph), adj_(csr) {}
}
CSR() { CSR() {
// set magic num_rows/num_cols to mark it as undefined // set magic num_rows/num_cols to mark it as undefined
// adj_.num_rows == 0 and adj_.num_cols == 0 means empty UnitGraph which is supported // adj_.num_rows == 0 and adj_.num_cols == 0 means empty UnitGraph which is
// supported
adj_.num_rows = -1; adj_.num_rows = -1;
adj_.num_cols = -1; adj_.num_cols = -1;
}; };
bool defined() const { bool defined() const { return (adj_.num_rows >= 0) || (adj_.num_cols >= 0); }
return (adj_.num_rows >= 0) || (adj_.num_cols >= 0);
}
inline dgl_type_t SrcType() const { inline dgl_type_t SrcType() const { return 0; }
return 0;
}
inline dgl_type_t DstType() const { inline dgl_type_t DstType() const { return NumVertexTypes() == 1 ? 0 : 1; }
return NumVertexTypes() == 1? 0 : 1;
}
inline dgl_type_t EdgeType() const { inline dgl_type_t EdgeType() const { return 0; }
return 0;
}
HeteroGraphPtr GetRelationGraph(dgl_type_t etype) const override { HeteroGraphPtr GetRelationGraph(dgl_type_t etype) const override {
LOG(FATAL) << "The method shouldn't be called for UnitGraph graph. " LOG(FATAL) << "The method shouldn't be called for UnitGraph graph. "
<< "The relation graph is simply this graph itself."; << "The relation graph is simply this graph itself.";
return {}; return {};
} }
...@@ -529,33 +504,22 @@ class UnitGraph::CSR : public BaseHeteroGraph { ...@@ -529,33 +504,22 @@ class UnitGraph::CSR : public BaseHeteroGraph {
LOG(FATAL) << "UnitGraph graph is not mutable."; LOG(FATAL) << "UnitGraph graph is not mutable.";
} }
void Clear() override { void Clear() override { LOG(FATAL) << "UnitGraph graph is not mutable."; }
LOG(FATAL) << "UnitGraph graph is not mutable.";
}
DGLDataType DataType() const override { DGLDataType DataType() const override { return adj_.indices->dtype; }
return adj_.indices->dtype;
}
DGLContext Context() const override { DGLContext Context() const override { return adj_.indices->ctx; }
return adj_.indices->ctx;
}
bool IsPinned() const override { bool IsPinned() const override { return adj_.is_pinned; }
return adj_.is_pinned;
}
uint8_t NumBits() const override { uint8_t NumBits() const override { return adj_.indices->dtype.bits; }
return adj_.indices->dtype.bits;
}
CSR AsNumBits(uint8_t bits) const { CSR AsNumBits(uint8_t bits) const {
if (NumBits() == bits) { if (NumBits() == bits) {
return *this; return *this;
} else { } else {
CSR ret( CSR ret(
meta_graph_, meta_graph_, adj_.num_rows, adj_.num_cols,
adj_.num_rows, adj_.num_cols,
aten::AsNumBits(adj_.indptr, bits), aten::AsNumBits(adj_.indptr, bits),
aten::AsNumBits(adj_.indices, bits), aten::AsNumBits(adj_.indices, bits),
aten::AsNumBits(adj_.data, bits)); aten::AsNumBits(adj_.data, bits));
...@@ -563,7 +527,7 @@ class UnitGraph::CSR : public BaseHeteroGraph { ...@@ -563,7 +527,7 @@ class UnitGraph::CSR : public BaseHeteroGraph {
} }
} }
CSR CopyTo(const DGLContext &ctx) const { CSR CopyTo(const DGLContext& ctx) const {
if (Context() == ctx) { if (Context() == ctx) {
return *this; return *this;
} else { } else {
...@@ -572,27 +536,19 @@ class UnitGraph::CSR : public BaseHeteroGraph { ...@@ -572,27 +536,19 @@ class UnitGraph::CSR : public BaseHeteroGraph {
} }
/** @brief Pin the adj_: CSRMatrix of the CSR graph. */ /** @brief Pin the adj_: CSRMatrix of the CSR graph. */
void PinMemory_() { void PinMemory_() { adj_.PinMemory_(); }
adj_.PinMemory_();
}
/** @brief Unpin the adj_: CSRMatrix of the CSR graph. */ /** @brief Unpin the adj_: CSRMatrix of the CSR graph. */
void UnpinMemory_() { void UnpinMemory_() { adj_.UnpinMemory_(); }
adj_.UnpinMemory_();
}
/** @brief Record stream for the adj_: CSRMatrix of the CSR graph. */ /** @brief Record stream for the adj_: CSRMatrix of the CSR graph. */
void RecordStream(DGLStreamHandle stream) override { void RecordStream(DGLStreamHandle stream) override {
adj_.RecordStream(stream); adj_.RecordStream(stream);
} }
bool IsMultigraph() const override { bool IsMultigraph() const override { return aten::CSRHasDuplicate(adj_); }
return aten::CSRHasDuplicate(adj_);
}
bool IsReadonly() const override { bool IsReadonly() const override { return true; }
return true;
}
uint64_t NumVertices(dgl_type_t vtype) const override { uint64_t NumVertices(dgl_type_t vtype) const override {
if (vtype == SrcType()) { if (vtype == SrcType()) {
...@@ -618,13 +574,15 @@ class UnitGraph::CSR : public BaseHeteroGraph { ...@@ -618,13 +574,15 @@ class UnitGraph::CSR : public BaseHeteroGraph {
return {}; return {};
} }
bool HasEdgeBetween(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const override { bool HasEdgeBetween(
dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const override {
CHECK(HasVertex(SrcType(), src)) << "Invalid src vertex id: " << src; CHECK(HasVertex(SrcType(), src)) << "Invalid src vertex id: " << src;
CHECK(HasVertex(DstType(), dst)) << "Invalid dst vertex id: " << dst; CHECK(HasVertex(DstType(), dst)) << "Invalid dst vertex id: " << dst;
return aten::CSRIsNonZero(adj_, src, dst); return aten::CSRIsNonZero(adj_, src, dst);
} }
BoolArray HasEdgesBetween(dgl_type_t etype, IdArray src_ids, IdArray dst_ids) const override { BoolArray HasEdgesBetween(
dgl_type_t etype, IdArray src_ids, IdArray dst_ids) const override {
CHECK(aten::IsValidIdArray(src_ids)) << "Invalid vertex id array."; CHECK(aten::IsValidIdArray(src_ids)) << "Invalid vertex id array.";
CHECK(aten::IsValidIdArray(dst_ids)) << "Invalid vertex id array."; CHECK(aten::IsValidIdArray(dst_ids)) << "Invalid vertex id array.";
return aten::CSRIsNonZero(adj_, src_ids, dst_ids); return aten::CSRIsNonZero(adj_, src_ids, dst_ids);
...@@ -646,18 +604,21 @@ class UnitGraph::CSR : public BaseHeteroGraph { ...@@ -646,18 +604,21 @@ class UnitGraph::CSR : public BaseHeteroGraph {
return aten::CSRGetAllData(adj_, src, dst); return aten::CSRGetAllData(adj_, src, dst);
} }
EdgeArray EdgeIdsAll(dgl_type_t etype, IdArray src, IdArray dst) const override { EdgeArray EdgeIdsAll(
dgl_type_t etype, IdArray src, IdArray dst) const override {
CHECK(aten::IsValidIdArray(src)) << "Invalid vertex id array."; CHECK(aten::IsValidIdArray(src)) << "Invalid vertex id array.";
CHECK(aten::IsValidIdArray(dst)) << "Invalid vertex id array."; CHECK(aten::IsValidIdArray(dst)) << "Invalid vertex id array.";
const auto& arrs = aten::CSRGetDataAndIndices(adj_, src, dst); const auto& arrs = aten::CSRGetDataAndIndices(adj_, src, dst);
return EdgeArray{arrs[0], arrs[1], arrs[2]}; return EdgeArray{arrs[0], arrs[1], arrs[2]};
} }
IdArray EdgeIdsOne(dgl_type_t etype, IdArray src, IdArray dst) const override { IdArray EdgeIdsOne(
dgl_type_t etype, IdArray src, IdArray dst) const override {
return aten::CSRGetData(adj_, src, dst); return aten::CSRGetData(adj_, src, dst);
} }
std::pair<dgl_id_t, dgl_id_t> FindEdge(dgl_type_t etype, dgl_id_t eid) const override { std::pair<dgl_id_t, dgl_id_t> FindEdge(
dgl_type_t etype, dgl_id_t eid) const override {
LOG(FATAL) << "Not enabled for CSR graph."; LOG(FATAL) << "Not enabled for CSR graph.";
return {}; return {};
} }
...@@ -681,7 +642,8 @@ class UnitGraph::CSR : public BaseHeteroGraph { ...@@ -681,7 +642,8 @@ class UnitGraph::CSR : public BaseHeteroGraph {
CHECK(HasVertex(SrcType(), vid)) << "Invalid src vertex id: " << vid; CHECK(HasVertex(SrcType(), vid)) << "Invalid src vertex id: " << vid;
IdArray ret_dst = aten::CSRGetRowColumnIndices(adj_, vid); IdArray ret_dst = aten::CSRGetRowColumnIndices(adj_, vid);
IdArray ret_eid = aten::CSRGetRowData(adj_, vid); IdArray ret_eid = aten::CSRGetRowData(adj_, vid);
IdArray ret_src = aten::Full(vid, ret_dst->shape[0], NumBits(), ret_dst->ctx); IdArray ret_src =
aten::Full(vid, ret_dst->shape[0], NumBits(), ret_dst->ctx);
return EdgeArray{ret_src, ret_dst, ret_eid}; return EdgeArray{ret_src, ret_dst, ret_eid};
} }
...@@ -695,10 +657,11 @@ class UnitGraph::CSR : public BaseHeteroGraph { ...@@ -695,10 +657,11 @@ class UnitGraph::CSR : public BaseHeteroGraph {
return EdgeArray{row, coosubmat.col, coosubmat.data}; return EdgeArray{row, coosubmat.col, coosubmat.data};
} }
EdgeArray Edges(dgl_type_t etype, const std::string &order = "") const override { EdgeArray Edges(
dgl_type_t etype, const std::string& order = "") const override {
CHECK(order.empty() || order == std::string("srcdst")) CHECK(order.empty() || order == std::string("srcdst"))
<< "CSR only support Edges of order \"srcdst\"," << "CSR only support Edges of order \"srcdst\","
<< " but got \"" << order << "\"."; << " but got \"" << order << "\".";
auto coo = aten::CSRToCOO(adj_, false); auto coo = aten::CSRToCOO(adj_, false);
if (order == std::string("srcdst")) { if (order == std::string("srcdst")) {
// make sure the coo is sorted if an order is requested // make sure the coo is sorted if an order is requested
...@@ -770,7 +733,7 @@ class UnitGraph::CSR : public BaseHeteroGraph { ...@@ -770,7 +733,7 @@ class UnitGraph::CSR : public BaseHeteroGraph {
} }
std::vector<IdArray> GetAdj( std::vector<IdArray> GetAdj(
dgl_type_t etype, bool transpose, const std::string &fmt) const override { dgl_type_t etype, bool transpose, const std::string& fmt) const override {
CHECK(!transpose && fmt == "csr") << "Not valid adj format request."; CHECK(!transpose && fmt == "csr") << "Not valid adj format request.";
return {adj_.indptr, adj_.indices, adj_.data}; return {adj_.indptr, adj_.indices, adj_.data};
} }
...@@ -785,11 +748,10 @@ class UnitGraph::CSR : public BaseHeteroGraph { ...@@ -785,11 +748,10 @@ class UnitGraph::CSR : public BaseHeteroGraph {
return aten::CSRMatrix(); return aten::CSRMatrix();
} }
aten::CSRMatrix GetCSRMatrix(dgl_type_t etype) const override { aten::CSRMatrix GetCSRMatrix(dgl_type_t etype) const override { return adj_; }
return adj_;
}
SparseFormat SelectFormat(dgl_type_t etype, dgl_format_code_t preferred_formats) const override { SparseFormat SelectFormat(
dgl_type_t etype, dgl_format_code_t preferred_formats) const override {
LOG(FATAL) << "Not enabled for CSR graph"; LOG(FATAL) << "Not enabled for CSR graph";
return SparseFormat::kCSR; return SparseFormat::kCSR;
} }
...@@ -804,8 +766,10 @@ class UnitGraph::CSR : public BaseHeteroGraph { ...@@ -804,8 +766,10 @@ class UnitGraph::CSR : public BaseHeteroGraph {
return 0; return 0;
} }
HeteroSubgraph VertexSubgraph(const std::vector<IdArray>& vids) const override { HeteroSubgraph VertexSubgraph(
CHECK_EQ(vids.size(), NumVertexTypes()) << "Number of vertex types mismatch"; const std::vector<IdArray>& vids) const override {
CHECK_EQ(vids.size(), NumVertexTypes())
<< "Number of vertex types mismatch";
auto srcvids = vids[SrcType()], dstvids = vids[DstType()]; auto srcvids = vids[SrcType()], dstvids = vids[DstType()];
CHECK(aten::IsValidIdArray(srcvids)) << "Invalid vertex id array."; CHECK(aten::IsValidIdArray(srcvids)) << "Invalid vertex id array.";
CHECK(aten::IsValidIdArray(dstvids)) << "Invalid vertex id array."; CHECK(aten::IsValidIdArray(dstvids)) << "Invalid vertex id array.";
...@@ -813,15 +777,17 @@ class UnitGraph::CSR : public BaseHeteroGraph { ...@@ -813,15 +777,17 @@ class UnitGraph::CSR : public BaseHeteroGraph {
const auto& submat = aten::CSRSliceMatrix(adj_, srcvids, dstvids); const auto& submat = aten::CSRSliceMatrix(adj_, srcvids, dstvids);
DGLContext ctx = aten::GetContextOf(vids); DGLContext ctx = aten::GetContextOf(vids);
IdArray sub_eids = aten::Range(0, submat.data->shape[0], NumBits(), ctx); IdArray sub_eids = aten::Range(0, submat.data->shape[0], NumBits(), ctx);
subg.graph = std::make_shared<CSR>(meta_graph(), submat.num_rows, submat.num_cols, subg.graph = std::make_shared<CSR>(
submat.indptr, submat.indices, sub_eids); meta_graph(), submat.num_rows, submat.num_cols, submat.indptr,
submat.indices, sub_eids);
subg.induced_vertices = vids; subg.induced_vertices = vids;
subg.induced_edges.emplace_back(submat.data); subg.induced_edges.emplace_back(submat.data);
return subg; return subg;
} }
HeteroSubgraph EdgeSubgraph( HeteroSubgraph EdgeSubgraph(
const std::vector<IdArray>& eids, bool preserve_nodes = false) const override { const std::vector<IdArray>& eids,
bool preserve_nodes = false) const override {
LOG(FATAL) << "Not enabled for CSR graph."; LOG(FATAL) << "Not enabled for CSR graph.";
return {}; return {};
} }
...@@ -831,9 +797,7 @@ class UnitGraph::CSR : public BaseHeteroGraph { ...@@ -831,9 +797,7 @@ class UnitGraph::CSR : public BaseHeteroGraph {
return nullptr; return nullptr;
} }
aten::CSRMatrix adj() const { aten::CSRMatrix adj() const { return adj_; }
return adj_;
}
bool Load(dmlc::Stream* fs) { bool Load(dmlc::Stream* fs) {
auto meta_imgraph = Serializer::make_shared<ImmutableGraph>(); auto meta_imgraph = Serializer::make_shared<ImmutableGraph>();
...@@ -861,21 +825,13 @@ class UnitGraph::CSR : public BaseHeteroGraph { ...@@ -861,21 +825,13 @@ class UnitGraph::CSR : public BaseHeteroGraph {
// //
////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////
DGLDataType UnitGraph::DataType() const { DGLDataType UnitGraph::DataType() const { return GetAny()->DataType(); }
return GetAny()->DataType();
}
DGLContext UnitGraph::Context() const { DGLContext UnitGraph::Context() const { return GetAny()->Context(); }
return GetAny()->Context();
}
bool UnitGraph::IsPinned() const { bool UnitGraph::IsPinned() const { return GetAny()->IsPinned(); }
return GetAny()->IsPinned();
}
uint8_t UnitGraph::NumBits() const { uint8_t UnitGraph::NumBits() const { return GetAny()->NumBits(); }
return GetAny()->NumBits();
}
bool UnitGraph::IsMultigraph() const { bool UnitGraph::IsMultigraph() const {
const SparseFormat fmt = SelectFormat(CSC_CODE); const SparseFormat fmt = SelectFormat(CSC_CODE);
...@@ -910,7 +866,8 @@ BoolArray UnitGraph::HasVertices(dgl_type_t vtype, IdArray vids) const { ...@@ -910,7 +866,8 @@ BoolArray UnitGraph::HasVertices(dgl_type_t vtype, IdArray vids) const {
return aten::LT(vids, NumVertices(vtype)); return aten::LT(vids, NumVertices(vtype));
} }
bool UnitGraph::HasEdgeBetween(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const { bool UnitGraph::HasEdgeBetween(
dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const {
const SparseFormat fmt = SelectFormat(CSC_CODE); const SparseFormat fmt = SelectFormat(CSC_CODE);
const auto ptr = GetFormat(fmt); const auto ptr = GetFormat(fmt);
if (fmt == SparseFormat::kCSC) if (fmt == SparseFormat::kCSC)
...@@ -953,7 +910,8 @@ IdArray UnitGraph::EdgeId(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const { ...@@ -953,7 +910,8 @@ IdArray UnitGraph::EdgeId(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const {
return ptr->EdgeId(etype, src, dst); return ptr->EdgeId(etype, src, dst);
} }
EdgeArray UnitGraph::EdgeIdsAll(dgl_type_t etype, IdArray src, IdArray dst) const { EdgeArray UnitGraph::EdgeIdsAll(
dgl_type_t etype, IdArray src, IdArray dst) const {
const SparseFormat fmt = SelectFormat(CSR_CODE); const SparseFormat fmt = SelectFormat(CSR_CODE);
const auto ptr = GetFormat(fmt); const auto ptr = GetFormat(fmt);
if (fmt == SparseFormat::kCSC) { if (fmt == SparseFormat::kCSC) {
...@@ -964,7 +922,8 @@ EdgeArray UnitGraph::EdgeIdsAll(dgl_type_t etype, IdArray src, IdArray dst) cons ...@@ -964,7 +922,8 @@ EdgeArray UnitGraph::EdgeIdsAll(dgl_type_t etype, IdArray src, IdArray dst) cons
} }
} }
IdArray UnitGraph::EdgeIdsOne(dgl_type_t etype, IdArray src, IdArray dst) const { IdArray UnitGraph::EdgeIdsOne(
dgl_type_t etype, IdArray src, IdArray dst) const {
const SparseFormat fmt = SelectFormat(CSR_CODE); const SparseFormat fmt = SelectFormat(CSR_CODE);
const auto ptr = GetFormat(fmt); const auto ptr = GetFormat(fmt);
if (fmt == SparseFormat::kCSC) { if (fmt == SparseFormat::kCSC) {
...@@ -974,7 +933,8 @@ IdArray UnitGraph::EdgeIdsOne(dgl_type_t etype, IdArray src, IdArray dst) const ...@@ -974,7 +933,8 @@ IdArray UnitGraph::EdgeIdsOne(dgl_type_t etype, IdArray src, IdArray dst) const
} }
} }
std::pair<dgl_id_t, dgl_id_t> UnitGraph::FindEdge(dgl_type_t etype, dgl_id_t eid) const { std::pair<dgl_id_t, dgl_id_t> UnitGraph::FindEdge(
dgl_type_t etype, dgl_id_t eid) const {
const SparseFormat fmt = SelectFormat(COO_CODE); const SparseFormat fmt = SelectFormat(COO_CODE);
const auto ptr = GetFormat(fmt); const auto ptr = GetFormat(fmt);
return ptr->FindEdge(etype, eid); return ptr->FindEdge(etype, eid);
...@@ -1020,7 +980,7 @@ EdgeArray UnitGraph::OutEdges(dgl_type_t etype, IdArray vids) const { ...@@ -1020,7 +980,7 @@ EdgeArray UnitGraph::OutEdges(dgl_type_t etype, IdArray vids) const {
return ptr->OutEdges(etype, vids); return ptr->OutEdges(etype, vids);
} }
EdgeArray UnitGraph::Edges(dgl_type_t etype, const std::string &order) const { EdgeArray UnitGraph::Edges(dgl_type_t etype, const std::string& order) const {
SparseFormat fmt; SparseFormat fmt;
if (order == std::string("eid")) { if (order == std::string("eid")) {
fmt = SelectFormat(COO_CODE); fmt = SelectFormat(COO_CODE);
...@@ -1117,17 +1077,18 @@ DGLIdIters UnitGraph::InEdgeVec(dgl_type_t etype, dgl_id_t vid) const { ...@@ -1117,17 +1077,18 @@ DGLIdIters UnitGraph::InEdgeVec(dgl_type_t etype, dgl_id_t vid) const {
} }
std::vector<IdArray> UnitGraph::GetAdj( std::vector<IdArray> UnitGraph::GetAdj(
dgl_type_t etype, bool transpose, const std::string &fmt) const { dgl_type_t etype, bool transpose, const std::string& fmt) const {
// TODO(minjie): Our current semantics of adjacency matrix is row for dst nodes and col for // TODO(minjie): Our current semantics of adjacency matrix is row for dst
// src nodes. Therefore, we need to flip the transpose flag. For example, transpose=False // nodes and col for src nodes. Therefore, we need to flip the transpose flag.
// is equal to in edge CSR. // For example,
// We have this behavior because previously we use framework's SPMM and we don't cache // transpose=False is equal to in edge CSR. We have this behavior because
// reverse adj. This is not intuitive and also not consistent with networkx's // previously we use framework's SPMM and we don't cache reverse adj. This
// to_scipy_sparse_matrix. With the upcoming custom kernel change, we should change the // is not intuitive and also not consistent with networkx's
// behavior and make row for src and col for dst. // to_scipy_sparse_matrix. With the upcoming custom kernel change, we should
// change the behavior and make row for src and col for dst.
if (fmt == std::string("csr")) { if (fmt == std::string("csr")) {
return !transpose ? GetOutCSR()->GetAdj(etype, false, "csr") return !transpose ? GetOutCSR()->GetAdj(etype, false, "csr")
: GetInCSR()->GetAdj(etype, false, "csr"); : GetInCSR()->GetAdj(etype, false, "csr");
} else if (fmt == std::string("coo")) { } else if (fmt == std::string("coo")) {
return GetCOO()->GetAdj(etype, transpose, fmt); return GetCOO()->GetAdj(etype, transpose, fmt);
} else { } else {
...@@ -1136,7 +1097,8 @@ std::vector<IdArray> UnitGraph::GetAdj( ...@@ -1136,7 +1097,8 @@ std::vector<IdArray> UnitGraph::GetAdj(
} }
} }
HeteroSubgraph UnitGraph::VertexSubgraph(const std::vector<IdArray>& vids) const { HeteroSubgraph UnitGraph::VertexSubgraph(
const std::vector<IdArray>& vids) const {
// We prefer to generate a subgraph from out-csr. // We prefer to generate a subgraph from out-csr.
SparseFormat fmt = SelectFormat(CSR_CODE); SparseFormat fmt = SelectFormat(CSR_CODE);
HeteroSubgraph sg = GetFormat(fmt)->VertexSubgraph(vids); HeteroSubgraph sg = GetFormat(fmt)->VertexSubgraph(vids);
...@@ -1160,7 +1122,8 @@ HeteroSubgraph UnitGraph::VertexSubgraph(const std::vector<IdArray>& vids) const ...@@ -1160,7 +1122,8 @@ HeteroSubgraph UnitGraph::VertexSubgraph(const std::vector<IdArray>& vids) const
return ret; return ret;
} }
ret.graph = HeteroGraphPtr(new UnitGraph(meta_graph(), subcsc, subcsr, subcoo)); ret.graph =
HeteroGraphPtr(new UnitGraph(meta_graph(), subcsc, subcsr, subcoo));
ret.induced_vertices = std::move(sg.induced_vertices); ret.induced_vertices = std::move(sg.induced_vertices);
ret.induced_edges = std::move(sg.induced_edges); ret.induced_edges = std::move(sg.induced_edges);
return ret; return ret;
...@@ -1190,80 +1153,67 @@ HeteroSubgraph UnitGraph::EdgeSubgraph( ...@@ -1190,80 +1153,67 @@ HeteroSubgraph UnitGraph::EdgeSubgraph(
return ret; return ret;
} }
ret.graph = HeteroGraphPtr(new UnitGraph(meta_graph(), subcsc, subcsr, subcoo)); ret.graph =
HeteroGraphPtr(new UnitGraph(meta_graph(), subcsc, subcsr, subcoo));
ret.induced_vertices = std::move(sg.induced_vertices); ret.induced_vertices = std::move(sg.induced_vertices);
ret.induced_edges = std::move(sg.induced_edges); ret.induced_edges = std::move(sg.induced_edges);
return ret; return ret;
} }
HeteroGraphPtr UnitGraph::CreateFromCOO( HeteroGraphPtr UnitGraph::CreateFromCOO(
int64_t num_vtypes, int64_t num_src, int64_t num_dst, int64_t num_vtypes, int64_t num_src, int64_t num_dst, IdArray row,
IdArray row, IdArray col, IdArray col, bool row_sorted, bool col_sorted, dgl_format_code_t formats) {
bool row_sorted, bool col_sorted,
dgl_format_code_t formats) {
CHECK(num_vtypes == 1 || num_vtypes == 2); CHECK(num_vtypes == 1 || num_vtypes == 2);
if (num_vtypes == 1) if (num_vtypes == 1) CHECK_EQ(num_src, num_dst);
CHECK_EQ(num_src, num_dst);
auto mg = CreateUnitGraphMetaGraph(num_vtypes); auto mg = CreateUnitGraphMetaGraph(num_vtypes);
COOPtr coo(new COO(mg, num_src, num_dst, row, col, COOPtr coo(new COO(mg, num_src, num_dst, row, col, row_sorted, col_sorted));
row_sorted, col_sorted));
return HeteroGraphPtr( return HeteroGraphPtr(new UnitGraph(mg, nullptr, nullptr, coo, formats));
new UnitGraph(mg, nullptr, nullptr, coo, formats));
} }
HeteroGraphPtr UnitGraph::CreateFromCOO( HeteroGraphPtr UnitGraph::CreateFromCOO(
int64_t num_vtypes, const aten::COOMatrix& mat, int64_t num_vtypes, const aten::COOMatrix& mat, dgl_format_code_t formats) {
dgl_format_code_t formats) {
CHECK(num_vtypes == 1 || num_vtypes == 2); CHECK(num_vtypes == 1 || num_vtypes == 2);
if (num_vtypes == 1) if (num_vtypes == 1) CHECK_EQ(mat.num_rows, mat.num_cols);
CHECK_EQ(mat.num_rows, mat.num_cols);
auto mg = CreateUnitGraphMetaGraph(num_vtypes); auto mg = CreateUnitGraphMetaGraph(num_vtypes);
COOPtr coo(new COO(mg, mat)); COOPtr coo(new COO(mg, mat));
return HeteroGraphPtr( return HeteroGraphPtr(new UnitGraph(mg, nullptr, nullptr, coo, formats));
new UnitGraph(mg, nullptr, nullptr, coo, formats));
} }
HeteroGraphPtr UnitGraph::CreateFromCSR( HeteroGraphPtr UnitGraph::CreateFromCSR(
int64_t num_vtypes, int64_t num_src, int64_t num_dst, int64_t num_vtypes, int64_t num_src, int64_t num_dst, IdArray indptr,
IdArray indptr, IdArray indices, IdArray edge_ids, dgl_format_code_t formats) { IdArray indices, IdArray edge_ids, dgl_format_code_t formats) {
CHECK(num_vtypes == 1 || num_vtypes == 2); CHECK(num_vtypes == 1 || num_vtypes == 2);
if (num_vtypes == 1) if (num_vtypes == 1) CHECK_EQ(num_src, num_dst);
CHECK_EQ(num_src, num_dst);
auto mg = CreateUnitGraphMetaGraph(num_vtypes); auto mg = CreateUnitGraphMetaGraph(num_vtypes);
CSRPtr csr(new CSR(mg, num_src, num_dst, indptr, indices, edge_ids)); CSRPtr csr(new CSR(mg, num_src, num_dst, indptr, indices, edge_ids));
return HeteroGraphPtr(new UnitGraph(mg, nullptr, csr, nullptr, formats)); return HeteroGraphPtr(new UnitGraph(mg, nullptr, csr, nullptr, formats));
} }
HeteroGraphPtr UnitGraph::CreateFromCSR( HeteroGraphPtr UnitGraph::CreateFromCSR(
int64_t num_vtypes, const aten::CSRMatrix& mat, int64_t num_vtypes, const aten::CSRMatrix& mat, dgl_format_code_t formats) {
dgl_format_code_t formats) {
CHECK(num_vtypes == 1 || num_vtypes == 2); CHECK(num_vtypes == 1 || num_vtypes == 2);
if (num_vtypes == 1) if (num_vtypes == 1) CHECK_EQ(mat.num_rows, mat.num_cols);
CHECK_EQ(mat.num_rows, mat.num_cols);
auto mg = CreateUnitGraphMetaGraph(num_vtypes); auto mg = CreateUnitGraphMetaGraph(num_vtypes);
CSRPtr csr(new CSR(mg, mat)); CSRPtr csr(new CSR(mg, mat));
return HeteroGraphPtr(new UnitGraph(mg, nullptr, csr, nullptr, formats)); return HeteroGraphPtr(new UnitGraph(mg, nullptr, csr, nullptr, formats));
} }
HeteroGraphPtr UnitGraph::CreateFromCSC( HeteroGraphPtr UnitGraph::CreateFromCSC(
int64_t num_vtypes, int64_t num_src, int64_t num_dst, int64_t num_vtypes, int64_t num_src, int64_t num_dst, IdArray indptr,
IdArray indptr, IdArray indices, IdArray edge_ids, dgl_format_code_t formats) { IdArray indices, IdArray edge_ids, dgl_format_code_t formats) {
CHECK(num_vtypes == 1 || num_vtypes == 2); CHECK(num_vtypes == 1 || num_vtypes == 2);
if (num_vtypes == 1) if (num_vtypes == 1) CHECK_EQ(num_src, num_dst);
CHECK_EQ(num_src, num_dst);
auto mg = CreateUnitGraphMetaGraph(num_vtypes); auto mg = CreateUnitGraphMetaGraph(num_vtypes);
CSRPtr csc(new CSR(mg, num_dst, num_src, indptr, indices, edge_ids)); CSRPtr csc(new CSR(mg, num_dst, num_src, indptr, indices, edge_ids));
return HeteroGraphPtr(new UnitGraph(mg, csc, nullptr, nullptr, formats)); return HeteroGraphPtr(new UnitGraph(mg, csc, nullptr, nullptr, formats));
} }
HeteroGraphPtr UnitGraph::CreateFromCSC( HeteroGraphPtr UnitGraph::CreateFromCSC(
int64_t num_vtypes, const aten::CSRMatrix& mat, int64_t num_vtypes, const aten::CSRMatrix& mat, dgl_format_code_t formats) {
dgl_format_code_t formats) {
CHECK(num_vtypes == 1 || num_vtypes == 2); CHECK(num_vtypes == 1 || num_vtypes == 2);
if (num_vtypes == 1) if (num_vtypes == 1) CHECK_EQ(mat.num_rows, mat.num_cols);
CHECK_EQ(mat.num_rows, mat.num_cols);
auto mg = CreateUnitGraphMetaGraph(num_vtypes); auto mg = CreateUnitGraphMetaGraph(num_vtypes);
CSRPtr csc(new CSR(mg, mat)); CSRPtr csc(new CSR(mg, mat));
return HeteroGraphPtr(new UnitGraph(mg, csc, nullptr, nullptr, formats)); return HeteroGraphPtr(new UnitGraph(mg, csc, nullptr, nullptr, formats));
...@@ -1275,18 +1225,21 @@ HeteroGraphPtr UnitGraph::AsNumBits(HeteroGraphPtr g, uint8_t bits) { ...@@ -1275,18 +1225,21 @@ HeteroGraphPtr UnitGraph::AsNumBits(HeteroGraphPtr g, uint8_t bits) {
} else { } else {
auto bg = std::dynamic_pointer_cast<UnitGraph>(g); auto bg = std::dynamic_pointer_cast<UnitGraph>(g);
CHECK_NOTNULL(bg); CHECK_NOTNULL(bg);
CSRPtr new_incsr = CSRPtr new_incsr = (bg->in_csr_->defined())
(bg->in_csr_->defined())? CSRPtr(new CSR(bg->in_csr_->AsNumBits(bits))) : nullptr; ? CSRPtr(new CSR(bg->in_csr_->AsNumBits(bits)))
CSRPtr new_outcsr = : nullptr;
(bg->out_csr_->defined())? CSRPtr(new CSR(bg->out_csr_->AsNumBits(bits))) : nullptr; CSRPtr new_outcsr = (bg->out_csr_->defined())
COOPtr new_coo = ? CSRPtr(new CSR(bg->out_csr_->AsNumBits(bits)))
(bg->coo_->defined())? COOPtr(new COO(bg->coo_->AsNumBits(bits))) : nullptr; : nullptr;
return HeteroGraphPtr( COOPtr new_coo = (bg->coo_->defined())
new UnitGraph(g->meta_graph(), new_incsr, new_outcsr, new_coo, bg->formats_)); ? COOPtr(new COO(bg->coo_->AsNumBits(bits)))
: nullptr;
return HeteroGraphPtr(new UnitGraph(
g->meta_graph(), new_incsr, new_outcsr, new_coo, bg->formats_));
} }
} }
HeteroGraphPtr UnitGraph::CopyTo(HeteroGraphPtr g, const DGLContext &ctx) { HeteroGraphPtr UnitGraph::CopyTo(HeteroGraphPtr g, const DGLContext& ctx) {
if (ctx == g->Context()) { if (ctx == g->Context()) {
return g; return g;
} else { } else {
...@@ -1301,54 +1254,43 @@ HeteroGraphPtr UnitGraph::CopyTo(HeteroGraphPtr g, const DGLContext &ctx) { ...@@ -1301,54 +1254,43 @@ HeteroGraphPtr UnitGraph::CopyTo(HeteroGraphPtr g, const DGLContext &ctx) {
COOPtr new_coo = (bg->coo_->defined()) COOPtr new_coo = (bg->coo_->defined())
? COOPtr(new COO(bg->coo_->CopyTo(ctx))) ? COOPtr(new COO(bg->coo_->CopyTo(ctx)))
: nullptr; : nullptr;
return HeteroGraphPtr( return HeteroGraphPtr(new UnitGraph(
new UnitGraph(g->meta_graph(), new_incsr, new_outcsr, new_coo, bg->formats_)); g->meta_graph(), new_incsr, new_outcsr, new_coo, bg->formats_));
} }
} }
void UnitGraph::PinMemory_() { void UnitGraph::PinMemory_() {
if (this->in_csr_->defined()) if (this->in_csr_->defined()) this->in_csr_->PinMemory_();
this->in_csr_->PinMemory_(); if (this->out_csr_->defined()) this->out_csr_->PinMemory_();
if (this->out_csr_->defined()) if (this->coo_->defined()) this->coo_->PinMemory_();
this->out_csr_->PinMemory_();
if (this->coo_->defined())
this->coo_->PinMemory_();
} }
void UnitGraph::UnpinMemory_() { void UnitGraph::UnpinMemory_() {
if (this->in_csr_->defined()) if (this->in_csr_->defined()) this->in_csr_->UnpinMemory_();
this->in_csr_->UnpinMemory_(); if (this->out_csr_->defined()) this->out_csr_->UnpinMemory_();
if (this->out_csr_->defined()) if (this->coo_->defined()) this->coo_->UnpinMemory_();
this->out_csr_->UnpinMemory_();
if (this->coo_->defined())
this->coo_->UnpinMemory_();
} }
void UnitGraph::RecordStream(DGLStreamHandle stream) { void UnitGraph::RecordStream(DGLStreamHandle stream) {
if (this->in_csr_->defined()) if (this->in_csr_->defined()) this->in_csr_->RecordStream(stream);
this->in_csr_->RecordStream(stream); if (this->out_csr_->defined()) this->out_csr_->RecordStream(stream);
if (this->out_csr_->defined()) if (this->coo_->defined()) this->coo_->RecordStream(stream);
this->out_csr_->RecordStream(stream);
if (this->coo_->defined())
this->coo_->RecordStream(stream);
this->recorded_streams.push_back(stream); this->recorded_streams.push_back(stream);
} }
void UnitGraph::InvalidateCSR() { void UnitGraph::InvalidateCSR() { this->out_csr_ = CSRPtr(new CSR()); }
this->out_csr_ = CSRPtr(new CSR());
}
void UnitGraph::InvalidateCSC() { void UnitGraph::InvalidateCSC() { this->in_csr_ = CSRPtr(new CSR()); }
this->in_csr_ = CSRPtr(new CSR());
}
void UnitGraph::InvalidateCOO() { void UnitGraph::InvalidateCOO() { this->coo_ = COOPtr(new COO()); }
this->coo_ = COOPtr(new COO());
}
UnitGraph::UnitGraph(GraphPtr metagraph, CSRPtr in_csr, CSRPtr out_csr, COOPtr coo, UnitGraph::UnitGraph(
dgl_format_code_t formats) GraphPtr metagraph, CSRPtr in_csr, CSRPtr out_csr, COOPtr coo,
: BaseHeteroGraph(metagraph), in_csr_(in_csr), out_csr_(out_csr), coo_(coo) { dgl_format_code_t formats)
: BaseHeteroGraph(metagraph),
in_csr_(in_csr),
out_csr_(out_csr),
coo_(coo) {
if (!in_csr_) { if (!in_csr_) {
in_csr_ = CSRPtr(new CSR()); in_csr_ = CSRPtr(new CSR());
} }
...@@ -1361,20 +1303,16 @@ UnitGraph::UnitGraph(GraphPtr metagraph, CSRPtr in_csr, CSRPtr out_csr, COOPtr c ...@@ -1361,20 +1303,16 @@ UnitGraph::UnitGraph(GraphPtr metagraph, CSRPtr in_csr, CSRPtr out_csr, COOPtr c
formats_ = formats; formats_ = formats;
dgl_format_code_t created = GetCreatedFormats(); dgl_format_code_t created = GetCreatedFormats();
if ((formats | created) != formats) if ((formats | created) != formats)
LOG(FATAL) << "Graph created from formats: " << CodeToStr(created) << LOG(FATAL) << "Graph created from formats: " << CodeToStr(created)
", which is not compatible with available formats: " << CodeToStr(formats); << ", which is not compatible with available formats: "
<< CodeToStr(formats);
CHECK(GetAny()) << "At least one graph structure should exist."; CHECK(GetAny()) << "At least one graph structure should exist.";
} }
HeteroGraphPtr UnitGraph::CreateUnitGraphFrom( HeteroGraphPtr UnitGraph::CreateUnitGraphFrom(
int num_vtypes, int num_vtypes, const aten::CSRMatrix& in_csr,
const aten::CSRMatrix &in_csr, const aten::CSRMatrix& out_csr, const aten::COOMatrix& coo, bool has_in_csr,
const aten::CSRMatrix &out_csr, bool has_out_csr, bool has_coo, dgl_format_code_t formats) {
const aten::COOMatrix &coo,
bool has_in_csr,
bool has_out_csr,
bool has_coo,
dgl_format_code_t formats) {
auto mg = CreateUnitGraphMetaGraph(num_vtypes); auto mg = CreateUnitGraphMetaGraph(num_vtypes);
CSRPtr in_csr_ptr = nullptr; CSRPtr in_csr_ptr = nullptr;
...@@ -1394,21 +1332,21 @@ HeteroGraphPtr UnitGraph::CreateUnitGraphFrom( ...@@ -1394,21 +1332,21 @@ HeteroGraphPtr UnitGraph::CreateUnitGraphFrom(
else else
coo_ptr = COOPtr(new COO()); coo_ptr = COOPtr(new COO());
return HeteroGraphPtr(new UnitGraph(mg, in_csr_ptr, out_csr_ptr, coo_ptr, formats)); return HeteroGraphPtr(
new UnitGraph(mg, in_csr_ptr, out_csr_ptr, coo_ptr, formats));
} }
UnitGraph::CSRPtr UnitGraph::GetInCSR(bool inplace) const { UnitGraph::CSRPtr UnitGraph::GetInCSR(bool inplace) const {
if (inplace) if (inplace)
if (!(formats_ & CSC_CODE)) if (!(formats_ & CSC_CODE))
LOG(FATAL) << "The graph have restricted sparse format " << LOG(FATAL) << "The graph have restricted sparse format "
CodeToStr(formats_) << ", cannot create CSC matrix."; << CodeToStr(formats_) << ", cannot create CSC matrix.";
CSRPtr ret = in_csr_; CSRPtr ret = in_csr_;
// Prefers converting from COO since it is parallelized. // Prefers converting from COO since it is parallelized.
// TODO(BarclayII): need benchmarking. // TODO(BarclayII): need benchmarking.
if (!in_csr_->defined()) { if (!in_csr_->defined()) {
if (coo_->defined()) { if (coo_->defined()) {
const auto& newadj = aten::COOToCSR( const auto& newadj = aten::COOToCSR(aten::COOTranspose(coo_->adj()));
aten::COOTranspose(coo_->adj()));
if (inplace) if (inplace)
*(const_cast<UnitGraph*>(this)->in_csr_) = CSR(meta_graph(), newadj); *(const_cast<UnitGraph*>(this)->in_csr_) = CSR(meta_graph(), newadj);
...@@ -1424,10 +1362,8 @@ UnitGraph::CSRPtr UnitGraph::GetInCSR(bool inplace) const { ...@@ -1424,10 +1362,8 @@ UnitGraph::CSRPtr UnitGraph::GetInCSR(bool inplace) const {
ret = std::make_shared<CSR>(meta_graph(), newadj); ret = std::make_shared<CSR>(meta_graph(), newadj);
} }
if (inplace) { if (inplace) {
if (IsPinned()) if (IsPinned()) in_csr_->PinMemory_();
in_csr_->PinMemory_(); for (auto stream : recorded_streams) in_csr_->RecordStream(stream);
for (auto stream : recorded_streams)
in_csr_->RecordStream(stream);
} }
} }
return ret; return ret;
...@@ -1437,8 +1373,8 @@ UnitGraph::CSRPtr UnitGraph::GetInCSR(bool inplace) const { ...@@ -1437,8 +1373,8 @@ UnitGraph::CSRPtr UnitGraph::GetInCSR(bool inplace) const {
UnitGraph::CSRPtr UnitGraph::GetOutCSR(bool inplace) const { UnitGraph::CSRPtr UnitGraph::GetOutCSR(bool inplace) const {
if (inplace) if (inplace)
if (!(formats_ & CSR_CODE)) if (!(formats_ & CSR_CODE))
LOG(FATAL) << "The graph have restricted sparse format " << LOG(FATAL) << "The graph have restricted sparse format "
CodeToStr(formats_) << ", cannot create CSR matrix."; << CodeToStr(formats_) << ", cannot create CSR matrix.";
CSRPtr ret = out_csr_; CSRPtr ret = out_csr_;
// Prefers converting from COO since it is parallelized. // Prefers converting from COO since it is parallelized.
// TODO(BarclayII): need benchmarking. // TODO(BarclayII): need benchmarking.
...@@ -1460,10 +1396,8 @@ UnitGraph::CSRPtr UnitGraph::GetOutCSR(bool inplace) const { ...@@ -1460,10 +1396,8 @@ UnitGraph::CSRPtr UnitGraph::GetOutCSR(bool inplace) const {
ret = std::make_shared<CSR>(meta_graph(), newadj); ret = std::make_shared<CSR>(meta_graph(), newadj);
} }
if (inplace) { if (inplace) {
if (IsPinned()) if (IsPinned()) out_csr_->PinMemory_();
out_csr_->PinMemory_(); for (auto stream : recorded_streams) out_csr_->RecordStream(stream);
for (auto stream : recorded_streams)
out_csr_->RecordStream(stream);
} }
} }
return ret; return ret;
...@@ -1473,12 +1407,13 @@ UnitGraph::CSRPtr UnitGraph::GetOutCSR(bool inplace) const { ...@@ -1473,12 +1407,13 @@ UnitGraph::CSRPtr UnitGraph::GetOutCSR(bool inplace) const {
UnitGraph::COOPtr UnitGraph::GetCOO(bool inplace) const { UnitGraph::COOPtr UnitGraph::GetCOO(bool inplace) const {
if (inplace) if (inplace)
if (!(formats_ & COO_CODE)) if (!(formats_ & COO_CODE))
LOG(FATAL) << "The graph have restricted sparse format " << LOG(FATAL) << "The graph have restricted sparse format "
CodeToStr(formats_) << ", cannot create COO matrix."; << CodeToStr(formats_) << ", cannot create COO matrix.";
COOPtr ret = coo_; COOPtr ret = coo_;
if (!coo_->defined()) { if (!coo_->defined()) {
if (in_csr_->defined()) { if (in_csr_->defined()) {
const auto& newadj = aten::COOTranspose(aten::CSRToCOO(in_csr_->adj(), true)); const auto& newadj =
aten::COOTranspose(aten::CSRToCOO(in_csr_->adj(), true));
if (inplace) if (inplace)
*(const_cast<UnitGraph*>(this)->coo_) = COO(meta_graph(), newadj); *(const_cast<UnitGraph*>(this)->coo_) = COO(meta_graph(), newadj);
...@@ -1494,10 +1429,8 @@ UnitGraph::COOPtr UnitGraph::GetCOO(bool inplace) const { ...@@ -1494,10 +1429,8 @@ UnitGraph::COOPtr UnitGraph::GetCOO(bool inplace) const {
ret = std::make_shared<COO>(meta_graph(), newadj); ret = std::make_shared<COO>(meta_graph(), newadj);
} }
if (inplace) { if (inplace) {
if (IsPinned()) if (IsPinned()) coo_->PinMemory_();
coo_->PinMemory_(); for (auto stream : recorded_streams) coo_->RecordStream(stream);
for (auto stream : recorded_streams)
coo_->RecordStream(stream);
} }
} }
return ret; return ret;
...@@ -1527,27 +1460,22 @@ HeteroGraphPtr UnitGraph::GetAny() const { ...@@ -1527,27 +1460,22 @@ HeteroGraphPtr UnitGraph::GetAny() const {
dgl_format_code_t UnitGraph::GetCreatedFormats() const { dgl_format_code_t UnitGraph::GetCreatedFormats() const {
dgl_format_code_t ret = 0; dgl_format_code_t ret = 0;
if (in_csr_->defined()) if (in_csr_->defined()) ret |= CSC_CODE;
ret |= CSC_CODE; if (out_csr_->defined()) ret |= CSR_CODE;
if (out_csr_->defined()) if (coo_->defined()) ret |= COO_CODE;
ret |= CSR_CODE;
if (coo_->defined())
ret |= COO_CODE;
return ret; return ret;
} }
dgl_format_code_t UnitGraph::GetAllowedFormats() const { dgl_format_code_t UnitGraph::GetAllowedFormats() const { return formats_; }
return formats_;
}
HeteroGraphPtr UnitGraph::GetFormat(SparseFormat format) const { HeteroGraphPtr UnitGraph::GetFormat(SparseFormat format) const {
switch (format) { switch (format) {
case SparseFormat::kCSR: case SparseFormat::kCSR:
return GetOutCSR(); return GetOutCSR();
case SparseFormat::kCSC: case SparseFormat::kCSC:
return GetInCSR(); return GetInCSR();
default: default:
return GetCOO(); return GetCOO();
} }
} }
...@@ -1555,17 +1483,11 @@ HeteroGraphPtr UnitGraph::GetGraphInFormat(dgl_format_code_t formats) const { ...@@ -1555,17 +1483,11 @@ HeteroGraphPtr UnitGraph::GetGraphInFormat(dgl_format_code_t formats) const {
if (formats == ALL_CODE) if (formats == ALL_CODE)
return HeteroGraphPtr( return HeteroGraphPtr(
// TODO(xiangsx) Make it as graph storage.Clone() // TODO(xiangsx) Make it as graph storage.Clone()
new UnitGraph(meta_graph_, new UnitGraph(
(in_csr_->defined()) meta_graph_,
? CSRPtr(new CSR(*in_csr_)) (in_csr_->defined()) ? CSRPtr(new CSR(*in_csr_)) : nullptr,
: nullptr, (out_csr_->defined()) ? CSRPtr(new CSR(*out_csr_)) : nullptr,
(out_csr_->defined()) (coo_->defined()) ? COOPtr(new COO(*coo_)) : nullptr, formats));
? CSRPtr(new CSR(*out_csr_))
: nullptr,
(coo_->defined())
? COOPtr(new COO(*coo_))
: nullptr,
formats));
int64_t num_vtypes = NumVertexTypes(); int64_t num_vtypes = NumVertexTypes();
if (formats & COO_CODE) if (formats & COO_CODE)
return CreateFromCOO(num_vtypes, GetCOO(false)->adj(), formats); return CreateFromCOO(num_vtypes, GetCOO(false)->adj(), formats);
...@@ -1574,18 +1496,17 @@ HeteroGraphPtr UnitGraph::GetGraphInFormat(dgl_format_code_t formats) const { ...@@ -1574,18 +1496,17 @@ HeteroGraphPtr UnitGraph::GetGraphInFormat(dgl_format_code_t formats) const {
return CreateFromCSC(num_vtypes, GetInCSR(false)->adj(), formats); return CreateFromCSC(num_vtypes, GetInCSR(false)->adj(), formats);
} }
SparseFormat UnitGraph::SelectFormat(dgl_format_code_t preferred_formats) const { SparseFormat UnitGraph::SelectFormat(
dgl_format_code_t preferred_formats) const {
dgl_format_code_t common = preferred_formats & formats_; dgl_format_code_t common = preferred_formats & formats_;
dgl_format_code_t created = GetCreatedFormats(); dgl_format_code_t created = GetCreatedFormats();
if (common & created) if (common & created) return DecodeFormat(common & created);
return DecodeFormat(common & created);
// NOTE(zihao): hypersparse is currently disabled since many CUDA operators on COO have // NOTE(zihao): hypersparse is currently disabled since many CUDA operators on
// not been implmented yet. // COO have not been implmented yet. if (coo_->defined() &&
// if (coo_->defined() && coo_->IsHypersparse()) // only allow coo for hypersparse graph. // coo_->IsHypersparse()) // only allow coo for hypersparse graph.
// return SparseFormat::kCOO; // return SparseFormat::kCOO;
if (common) if (common) return DecodeFormat(common);
return DecodeFormat(common);
return DecodeFormat(created); return DecodeFormat(created);
} }
...@@ -1623,13 +1544,15 @@ HeteroGraphPtr UnitGraph::LineGraph(bool backtracking) const { ...@@ -1623,13 +1544,15 @@ HeteroGraphPtr UnitGraph::LineGraph(bool backtracking) const {
} }
case SparseFormat::kCSR: { case SparseFormat::kCSR: {
const aten::CSRMatrix csr = GetCSRMatrix(0); const aten::CSRMatrix csr = GetCSRMatrix(0);
const aten::COOMatrix coo = aten::COOLineGraph(aten::CSRToCOO(csr, true), backtracking); const aten::COOMatrix coo =
aten::COOLineGraph(aten::CSRToCOO(csr, true), backtracking);
return CreateFromCOO(1, coo); return CreateFromCOO(1, coo);
} }
case SparseFormat::kCSC: { case SparseFormat::kCSC: {
const aten::CSRMatrix csc = GetCSCMatrix(0); const aten::CSRMatrix csc = GetCSCMatrix(0);
const aten::CSRMatrix csr = aten::CSRTranspose(csc); const aten::CSRMatrix csr = aten::CSRTranspose(csc);
const aten::COOMatrix coo = aten::COOLineGraph(aten::CSRToCOO(csr, true), backtracking); const aten::COOMatrix coo =
aten::COOLineGraph(aten::CSRToCOO(csr, true), backtracking);
return CreateFromCOO(1, coo); return CreateFromCOO(1, coo);
} }
default: default:
...@@ -1662,21 +1585,21 @@ bool UnitGraph::Load(dmlc::Stream* fs) { ...@@ -1662,21 +1585,21 @@ bool UnitGraph::Load(dmlc::Stream* fs) {
} else { } else {
// NOTE(zihao): to be compatible with old formats. // NOTE(zihao): to be compatible with old formats.
switch (formats_code & 0xffffffff) { switch (formats_code & 0xffffffff) {
case 0: case 0:
formats_ = ALL_CODE; formats_ = ALL_CODE;
break; break;
case 1: case 1:
formats_ = COO_CODE; formats_ = COO_CODE;
break; break;
case 2: case 2:
formats_ = CSR_CODE; formats_ = CSR_CODE;
break; break;
case 3: case 3:
formats_ = CSC_CODE; formats_ = CSC_CODE;
break; break;
default: default:
LOG(FATAL) << "Load graph failed, formats code " << formats_code << LOG(FATAL) << "Load graph failed, formats code " << formats_code
"not recognized."; << "not recognized.";
} }
} }
...@@ -1708,13 +1631,12 @@ bool UnitGraph::Load(dmlc::Stream* fs) { ...@@ -1708,13 +1631,12 @@ bool UnitGraph::Load(dmlc::Stream* fs) {
return true; return true;
} }
void UnitGraph::Save(dmlc::Stream* fs) const { void UnitGraph::Save(dmlc::Stream* fs) const {
fs->Write(kDGLSerialize_UnitGraphMagic); fs->Write(kDGLSerialize_UnitGraphMagic);
// Didn't write UnitGraph::meta_graph_, since it's included in the underlying // Didn't write UnitGraph::meta_graph_, since it's included in the underlying
// sparse matrix // sparse matrix
auto save_formats = SparseFormatsToCode({SelectFormat(ALL_CODE)}); auto save_formats = SparseFormatsToCode({SelectFormat(ALL_CODE)});
auto fstream = dynamic_cast<dgl::serialize::DGLStream *>(fs); auto fstream = dynamic_cast<dgl::serialize::DGLStream*>(fs);
if (fstream) { if (fstream) {
auto formats = fstream->FormatsToSave(); auto formats = fstream->FormatsToSave();
save_formats = formats == ANY_CODE save_formats = formats == ANY_CODE
...@@ -1738,14 +1660,15 @@ UnitGraphPtr UnitGraph::Reverse() const { ...@@ -1738,14 +1660,15 @@ UnitGraphPtr UnitGraph::Reverse() const {
CSRPtr new_incsr = out_csr_, new_outcsr = in_csr_; CSRPtr new_incsr = out_csr_, new_outcsr = in_csr_;
COOPtr new_coo = nullptr; COOPtr new_coo = nullptr;
if (coo_->defined()) { if (coo_->defined()) {
new_coo = COOPtr(new COO(coo_->meta_graph(), aten::COOTranspose(coo_->adj()))); new_coo =
COOPtr(new COO(coo_->meta_graph(), aten::COOTranspose(coo_->adj())));
} }
return UnitGraphPtr(new UnitGraph(meta_graph(), new_incsr, new_outcsr, new_coo)); return UnitGraphPtr(
new UnitGraph(meta_graph(), new_incsr, new_outcsr, new_coo));
} }
std::tuple<UnitGraphPtr, IdArray, IdArray> std::tuple<UnitGraphPtr, IdArray, IdArray> UnitGraph::ToSimple() const {
UnitGraph::ToSimple() const {
CSRPtr new_incsr = nullptr, new_outcsr = nullptr; CSRPtr new_incsr = nullptr, new_outcsr = nullptr;
COOPtr new_coo = nullptr; COOPtr new_coo = nullptr;
IdArray count; IdArray count;
...@@ -1779,9 +1702,9 @@ UnitGraph::ToSimple() const { ...@@ -1779,9 +1702,9 @@ UnitGraph::ToSimple() const {
break; break;
} }
return std::make_tuple(UnitGraphPtr(new UnitGraph(meta_graph(), new_incsr, new_outcsr, new_coo)), return std::make_tuple(
count, UnitGraphPtr(new UnitGraph(meta_graph(), new_incsr, new_outcsr, new_coo)),
edge_map); count, edge_map);
} }
} // namespace dgl } // namespace dgl
...@@ -7,16 +7,17 @@ ...@@ -7,16 +7,17 @@
#ifndef DGL_GRAPH_UNIT_GRAPH_H_ #ifndef DGL_GRAPH_UNIT_GRAPH_H_
#define DGL_GRAPH_UNIT_GRAPH_H_ #define DGL_GRAPH_UNIT_GRAPH_H_
#include <dgl/array.h>
#include <dgl/base_heterograph.h> #include <dgl/base_heterograph.h>
#include <dgl/lazy.h> #include <dgl/lazy.h>
#include <dgl/array.h>
#include <dmlc/io.h> #include <dmlc/io.h>
#include <dmlc/type_traits.h> #include <dmlc/type_traits.h>
#include <utility>
#include <string>
#include <vector>
#include <memory> #include <memory>
#include <string>
#include <tuple> #include <tuple>
#include <utility>
#include <vector>
#include "../c_api_common.h" #include "../c_api_common.h"
...@@ -45,21 +46,15 @@ class UnitGraph : public BaseHeteroGraph { ...@@ -45,21 +46,15 @@ class UnitGraph : public BaseHeteroGraph {
typedef std::shared_ptr<COO> COOPtr; typedef std::shared_ptr<COO> COOPtr;
typedef std::shared_ptr<CSR> CSRPtr; typedef std::shared_ptr<CSR> CSRPtr;
inline dgl_type_t SrcType() const { inline dgl_type_t SrcType() const { return 0; }
return 0;
}
inline dgl_type_t DstType() const { inline dgl_type_t DstType() const { return NumVertexTypes() == 1 ? 0 : 1; }
return NumVertexTypes() == 1? 0 : 1;
}
inline dgl_type_t EdgeType() const { inline dgl_type_t EdgeType() const { return 0; }
return 0;
}
HeteroGraphPtr GetRelationGraph(dgl_type_t etype) const override { HeteroGraphPtr GetRelationGraph(dgl_type_t etype) const override {
LOG(FATAL) << "The method shouldn't be called for UnitGraph graph. " LOG(FATAL) << "The method shouldn't be called for UnitGraph graph. "
<< "The relation graph is simply this graph itself."; << "The relation graph is simply this graph itself.";
return {}; return {};
} }
...@@ -75,9 +70,7 @@ class UnitGraph : public BaseHeteroGraph { ...@@ -75,9 +70,7 @@ class UnitGraph : public BaseHeteroGraph {
LOG(FATAL) << "UnitGraph graph is not mutable."; LOG(FATAL) << "UnitGraph graph is not mutable.";
} }
void Clear() override { void Clear() override { LOG(FATAL) << "UnitGraph graph is not mutable."; }
LOG(FATAL) << "UnitGraph graph is not mutable.";
}
DGLDataType DataType() const override; DGLDataType DataType() const override;
...@@ -89,9 +82,7 @@ class UnitGraph : public BaseHeteroGraph { ...@@ -89,9 +82,7 @@ class UnitGraph : public BaseHeteroGraph {
bool IsMultigraph() const override; bool IsMultigraph() const override;
bool IsReadonly() const override { bool IsReadonly() const override { return true; }
return true;
}
uint64_t NumVertices(dgl_type_t vtype) const override; uint64_t NumVertices(dgl_type_t vtype) const override;
...@@ -108,9 +99,11 @@ class UnitGraph : public BaseHeteroGraph { ...@@ -108,9 +99,11 @@ class UnitGraph : public BaseHeteroGraph {
BoolArray HasVertices(dgl_type_t vtype, IdArray vids) const override; BoolArray HasVertices(dgl_type_t vtype, IdArray vids) const override;
bool HasEdgeBetween(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const override; bool HasEdgeBetween(
dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const override;
BoolArray HasEdgesBetween(dgl_type_t etype, IdArray src_ids, IdArray dst_ids) const override; BoolArray HasEdgesBetween(
dgl_type_t etype, IdArray src_ids, IdArray dst_ids) const override;
IdArray Predecessors(dgl_type_t etype, dgl_id_t dst) const override; IdArray Predecessors(dgl_type_t etype, dgl_id_t dst) const override;
...@@ -118,11 +111,13 @@ class UnitGraph : public BaseHeteroGraph { ...@@ -118,11 +111,13 @@ class UnitGraph : public BaseHeteroGraph {
IdArray EdgeId(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const override; IdArray EdgeId(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const override;
EdgeArray EdgeIdsAll(dgl_type_t etype, IdArray src, IdArray dst) const override; EdgeArray EdgeIdsAll(
dgl_type_t etype, IdArray src, IdArray dst) const override;
IdArray EdgeIdsOne(dgl_type_t etype, IdArray src, IdArray dst) const override; IdArray EdgeIdsOne(dgl_type_t etype, IdArray src, IdArray dst) const override;
std::pair<dgl_id_t, dgl_id_t> FindEdge(dgl_type_t etype, dgl_id_t eid) const override; std::pair<dgl_id_t, dgl_id_t> FindEdge(
dgl_type_t etype, dgl_id_t eid) const override;
EdgeArray FindEdges(dgl_type_t etype, IdArray eids) const override; EdgeArray FindEdges(dgl_type_t etype, IdArray eids) const override;
...@@ -134,7 +129,8 @@ class UnitGraph : public BaseHeteroGraph { ...@@ -134,7 +129,8 @@ class UnitGraph : public BaseHeteroGraph {
EdgeArray OutEdges(dgl_type_t etype, IdArray vids) const override; EdgeArray OutEdges(dgl_type_t etype, IdArray vids) const override;
EdgeArray Edges(dgl_type_t etype, const std::string &order = "") const override; EdgeArray Edges(
dgl_type_t etype, const std::string& order = "") const override;
uint64_t InDegree(dgl_type_t etype, dgl_id_t vid) const override; uint64_t InDegree(dgl_type_t etype, dgl_id_t vid) const override;
...@@ -156,18 +152,20 @@ class UnitGraph : public BaseHeteroGraph { ...@@ -156,18 +152,20 @@ class UnitGraph : public BaseHeteroGraph {
DGLIdIters InEdgeVec(dgl_type_t etype, dgl_id_t vid) const override; DGLIdIters InEdgeVec(dgl_type_t etype, dgl_id_t vid) const override;
std::vector<IdArray> GetAdj( std::vector<IdArray> GetAdj(
dgl_type_t etype, bool transpose, const std::string &fmt) const override; dgl_type_t etype, bool transpose, const std::string& fmt) const override;
HeteroSubgraph VertexSubgraph(const std::vector<IdArray>& vids) const override; HeteroSubgraph VertexSubgraph(
const std::vector<IdArray>& vids) const override;
HeteroSubgraph EdgeSubgraph( HeteroSubgraph EdgeSubgraph(
const std::vector<IdArray>& eids, bool preserve_nodes = false) const override; const std::vector<IdArray>& eids,
bool preserve_nodes = false) const override;
// creators // creators
/** @brief Create a graph with no edges */ /** @brief Create a graph with no edges */
static HeteroGraphPtr Empty( static HeteroGraphPtr Empty(
int64_t num_vtypes, int64_t num_src, int64_t num_dst, int64_t num_vtypes, int64_t num_src, int64_t num_dst, DGLDataType dtype,
DGLDataType dtype, DGLContext ctx) { DGLContext ctx) {
IdArray row = IdArray::Empty({0}, dtype, ctx); IdArray row = IdArray::Empty({0}, dtype, ctx);
IdArray col = IdArray::Empty({0}, dtype, ctx); IdArray col = IdArray::Empty({0}, dtype, ctx);
return CreateFromCOO(num_vtypes, num_src, num_dst, row, col); return CreateFromCOO(num_vtypes, num_src, num_dst, row, col);
...@@ -175,9 +173,9 @@ class UnitGraph : public BaseHeteroGraph { ...@@ -175,9 +173,9 @@ class UnitGraph : public BaseHeteroGraph {
/** @brief Create a graph from COO arrays */ /** @brief Create a graph from COO arrays */
static HeteroGraphPtr CreateFromCOO( static HeteroGraphPtr CreateFromCOO(
int64_t num_vtypes, int64_t num_src, int64_t num_dst, int64_t num_vtypes, int64_t num_src, int64_t num_dst, IdArray row,
IdArray row, IdArray col, bool row_sorted = false, IdArray col, bool row_sorted = false, bool col_sorted = false,
bool col_sorted = false, dgl_format_code_t formats = ALL_CODE); dgl_format_code_t formats = ALL_CODE);
static HeteroGraphPtr CreateFromCOO( static HeteroGraphPtr CreateFromCOO(
int64_t num_vtypes, const aten::COOMatrix& mat, int64_t num_vtypes, const aten::COOMatrix& mat,
...@@ -185,9 +183,8 @@ class UnitGraph : public BaseHeteroGraph { ...@@ -185,9 +183,8 @@ class UnitGraph : public BaseHeteroGraph {
/** @brief Create a graph from (out) CSR arrays */ /** @brief Create a graph from (out) CSR arrays */
static HeteroGraphPtr CreateFromCSR( static HeteroGraphPtr CreateFromCSR(
int64_t num_vtypes, int64_t num_src, int64_t num_dst, int64_t num_vtypes, int64_t num_src, int64_t num_dst, IdArray indptr,
IdArray indptr, IdArray indices, IdArray edge_ids, IdArray indices, IdArray edge_ids, dgl_format_code_t formats = ALL_CODE);
dgl_format_code_t formats = ALL_CODE);
static HeteroGraphPtr CreateFromCSR( static HeteroGraphPtr CreateFromCSR(
int64_t num_vtypes, const aten::CSRMatrix& mat, int64_t num_vtypes, const aten::CSRMatrix& mat,
...@@ -195,9 +192,8 @@ class UnitGraph : public BaseHeteroGraph { ...@@ -195,9 +192,8 @@ class UnitGraph : public BaseHeteroGraph {
/** @brief Create a graph from (in) CSC arrays */ /** @brief Create a graph from (in) CSC arrays */
static HeteroGraphPtr CreateFromCSC( static HeteroGraphPtr CreateFromCSC(
int64_t num_vtypes, int64_t num_src, int64_t num_dst, int64_t num_vtypes, int64_t num_src, int64_t num_dst, IdArray indptr,
IdArray indptr, IdArray indices, IdArray edge_ids, IdArray indices, IdArray edge_ids, dgl_format_code_t formats = ALL_CODE);
dgl_format_code_t formats = ALL_CODE);
static HeteroGraphPtr CreateFromCSC( static HeteroGraphPtr CreateFromCSC(
int64_t num_vtypes, const aten::CSRMatrix& mat, int64_t num_vtypes, const aten::CSRMatrix& mat,
...@@ -207,25 +203,23 @@ class UnitGraph : public BaseHeteroGraph { ...@@ -207,25 +203,23 @@ class UnitGraph : public BaseHeteroGraph {
static HeteroGraphPtr AsNumBits(HeteroGraphPtr g, uint8_t bits); static HeteroGraphPtr AsNumBits(HeteroGraphPtr g, uint8_t bits);
/** @brief Copy the data to another context */ /** @brief Copy the data to another context */
static HeteroGraphPtr CopyTo(HeteroGraphPtr g, const DGLContext &ctx); static HeteroGraphPtr CopyTo(HeteroGraphPtr g, const DGLContext& ctx);
/** /**
* @brief Pin the in_csr_, out_scr_ and coo_ of the current graph. * @brief Pin the in_csr_, out_scr_ and coo_ of the current graph.
* @note The graph will be pinned inplace. Behavior depends on the current context, * @note The graph will be pinned inplace. Behavior depends on the current
* kDGLCPU: will be pinned; * context, kDGLCPU: will be pinned; IsPinned: directly return; kDGLCUDA:
* IsPinned: directly return; * invalid, will throw an error. The context check is deferred to pinning the
* kDGLCUDA: invalid, will throw an error. * NDArray.
* The context check is deferred to pinning the NDArray. */
*/
void PinMemory_() override; void PinMemory_() override;
/** /**
* @brief Unpin the in_csr_, out_scr_ and coo_ of the current graph. * @brief Unpin the in_csr_, out_scr_ and coo_ of the current graph.
* @note The graph will be unpinned inplace. Behavior depends on the current context, * @note The graph will be unpinned inplace. Behavior depends on the current
* IsPinned: will be unpinned; * context, IsPinned: will be unpinned; others: directly return. The context
* others: directly return. * check is deferred to unpinning the NDArray.
* The context check is deferred to unpinning the NDArray. */
*/
void UnpinMemory_(); void UnpinMemory_();
/** /**
...@@ -234,26 +228,31 @@ class UnitGraph : public BaseHeteroGraph { ...@@ -234,26 +228,31 @@ class UnitGraph : public BaseHeteroGraph {
*/ */
void RecordStream(DGLStreamHandle stream) override; void RecordStream(DGLStreamHandle stream) override;
/** /**
* @brief Create in-edge CSR format of the unit graph. * @brief Create in-edge CSR format of the unit graph.
* @param inplace if true and the in-edge CSR format does not exist, the created * @param inplace if true and the in-edge CSR format does not exist, the
* format will be cached in this object unless the format is restricted. * created format will be cached in this object unless the format is
* @return Return the in-edge CSR format. Create from other format if not exist. * restricted.
* @return Return the in-edge CSR format. Create from other format if not
* exist.
*/ */
CSRPtr GetInCSR(bool inplace = true) const; CSRPtr GetInCSR(bool inplace = true) const;
/** /**
* @brief Create out-edge CSR format of the unit graph. * @brief Create out-edge CSR format of the unit graph.
* @param inplace if true and the out-edge CSR format does not exist, the created * @param inplace if true and the out-edge CSR format does not exist, the
* format will be cached in this object unless the format is restricted. * created format will be cached in this object unless the format is
* @return Return the out-edge CSR format. Create from other format if not exist. * restricted.
* @return Return the out-edge CSR format. Create from other format if not
* exist.
*/ */
CSRPtr GetOutCSR(bool inplace = true) const; CSRPtr GetOutCSR(bool inplace = true) const;
/** /**
* @brief Create COO format of the unit graph. * @brief Create COO format of the unit graph.
* @param inplace if true and the COO format does not exist, the created * @param inplace if true and the COO format does not exist, the created
* format will be cached in this object unless the format is restricted. * format will be cached in this object unless the format is
* restricted.
* @return Return the COO format. Create from other format if not exist. * @return Return the COO format. Create from other format if not exist.
*/ */
COOPtr GetCOO(bool inplace = true) const; COOPtr GetCOO(bool inplace = true) const;
...@@ -267,13 +266,14 @@ class UnitGraph : public BaseHeteroGraph { ...@@ -267,13 +266,14 @@ class UnitGraph : public BaseHeteroGraph {
/** @return Return the out-edge CSR in the matrix form */ /** @return Return the out-edge CSR in the matrix form */
aten::CSRMatrix GetCSRMatrix(dgl_type_t etype) const override; aten::CSRMatrix GetCSRMatrix(dgl_type_t etype) const override;
SparseFormat SelectFormat(dgl_type_t etype, dgl_format_code_t preferred_formats) const override { SparseFormat SelectFormat(
dgl_type_t etype, dgl_format_code_t preferred_formats) const override {
return SelectFormat(preferred_formats); return SelectFormat(preferred_formats);
} }
/** /**
* @brief Return the graph in the given format. Perform format conversion if the * @brief Return the graph in the given format. Perform format conversion if
* requested format does not exist. * the requested format does not exist.
* *
* @return A graph in the requested format. * @return A graph in the requested format.
*/ */
...@@ -298,11 +298,11 @@ class UnitGraph : public BaseHeteroGraph { ...@@ -298,11 +298,11 @@ class UnitGraph : public BaseHeteroGraph {
UnitGraphPtr Reverse() const; UnitGraphPtr Reverse() const;
/** @return the simpled (no-multi-edge) graph /** @return the simpled (no-multi-edge) graph
* the count recording the number of duplicated edges from the original graph. * the count recording the number of duplicated edges from the
* the edge mapping from the edge IDs of original graph to those of the * original graph. the edge mapping from the edge IDs of original graph to
* returned graph. * those of the returned graph.
*/ */
std::tuple<UnitGraphPtr, IdArray, IdArray>ToSimple() const; std::tuple<UnitGraphPtr, IdArray, IdArray> ToSimple() const;
void InvalidateCSR(); void InvalidateCSR();
...@@ -326,8 +326,9 @@ class UnitGraph : public BaseHeteroGraph { ...@@ -326,8 +326,9 @@ class UnitGraph : public BaseHeteroGraph {
* @param out_csr out edge csr * @param out_csr out edge csr
* @param coo coo * @param coo coo
*/ */
UnitGraph(GraphPtr metagraph, CSRPtr in_csr, CSRPtr out_csr, COOPtr coo, UnitGraph(
dgl_format_code_t formats = ALL_CODE); GraphPtr metagraph, CSRPtr in_csr, CSRPtr out_csr, COOPtr coo,
dgl_format_code_t formats = ALL_CODE);
/** /**
* @brief constructor * @brief constructor
...@@ -341,13 +342,9 @@ class UnitGraph : public BaseHeteroGraph { ...@@ -341,13 +342,9 @@ class UnitGraph : public BaseHeteroGraph {
* @param has_coo whether coo is valid * @param has_coo whether coo is valid
*/ */
static HeteroGraphPtr CreateUnitGraphFrom( static HeteroGraphPtr CreateUnitGraphFrom(
int num_vtypes, int num_vtypes, const aten::CSRMatrix& in_csr,
const aten::CSRMatrix &in_csr, const aten::CSRMatrix& out_csr, const aten::COOMatrix& coo,
const aten::CSRMatrix &out_csr, bool has_in_csr, bool has_out_csr, bool has_coo,
const aten::COOMatrix &coo,
bool has_in_csr,
bool has_out_csr,
bool has_coo,
dgl_format_code_t formats = ALL_CODE); dgl_format_code_t formats = ALL_CODE);
/** @return Return any existing format. */ /** @return Return any existing format. */
...@@ -356,8 +353,8 @@ class UnitGraph : public BaseHeteroGraph { ...@@ -356,8 +353,8 @@ class UnitGraph : public BaseHeteroGraph {
/** /**
* @brief Determine which format to use with a preference. * @brief Determine which format to use with a preference.
* *
* If the storage of unit graph is "locked", i.e. no conversion is allowed, then * If the storage of unit graph is "locked", i.e. no conversion is allowed,
* it will return the locked format. * then it will return the locked format.
* *
* Otherwise, it will return whatever DGL thinks is the most appropriate given * Otherwise, it will return whatever DGL thinks is the most appropriate given
* the arguments. * the arguments.
...@@ -369,8 +366,8 @@ class UnitGraph : public BaseHeteroGraph { ...@@ -369,8 +366,8 @@ class UnitGraph : public BaseHeteroGraph {
GraphPtr AsImmutableGraph() const override; GraphPtr AsImmutableGraph() const override;
// Graph stored in different format. We use an on-demand strategy: the format is // Graph stored in different format. We use an on-demand strategy: the format
// only materialized if the operation that suitable for it is invoked. // is only materialized if the operation that suitable for it is invoked.
/** @brief CSR graph that stores reverse edges */ /** @brief CSR graph that stores reverse edges */
CSRPtr in_csr_; CSRPtr in_csr_;
/** @brief CSR representation */ /** @brief CSR representation */
......
...@@ -6,8 +6,10 @@ ...@@ -6,8 +6,10 @@
#include <dgl/array.h> #include <dgl/array.h>
#include <dgl/random.h> #include <dgl/random.h>
#include <numeric> #include <numeric>
#include <vector> #include <vector>
#include "sample_utils.h" #include "sample_utils.h"
namespace dgl { namespace dgl {
...@@ -27,12 +29,12 @@ template int32_t RandomEngine::Choice<int32_t>(FloatArray); ...@@ -27,12 +29,12 @@ template int32_t RandomEngine::Choice<int32_t>(FloatArray);
template int64_t RandomEngine::Choice<int64_t>(FloatArray); template int64_t RandomEngine::Choice<int64_t>(FloatArray);
template <typename IdxType, typename FloatType> template <typename IdxType, typename FloatType>
void RandomEngine::Choice(IdxType num, FloatArray prob, IdxType* out, void RandomEngine::Choice(
bool replace) { IdxType num, FloatArray prob, IdxType* out, bool replace) {
const IdxType N = prob->shape[0]; const IdxType N = prob->shape[0];
if (!replace) if (!replace)
CHECK_LE(num, N) CHECK_LE(num, N)
<< "Cannot take more sample than population when 'replace=false'"; << "Cannot take more sample than population when 'replace=false'";
if (num == N && !replace) std::iota(out, out + num, 0); if (num == N && !replace) std::iota(out, out + num, 0);
utils::BaseSampler<IdxType>* sampler = nullptr; utils::BaseSampler<IdxType>* sampler = nullptr;
...@@ -45,35 +47,31 @@ void RandomEngine::Choice(IdxType num, FloatArray prob, IdxType* out, ...@@ -45,35 +47,31 @@ void RandomEngine::Choice(IdxType num, FloatArray prob, IdxType* out,
delete sampler; delete sampler;
} }
template void RandomEngine::Choice<int32_t, float>(int32_t num, FloatArray prob, template void RandomEngine::Choice<int32_t, float>(
int32_t* out, bool replace); int32_t num, FloatArray prob, int32_t* out, bool replace);
template void RandomEngine::Choice<int64_t, float>(int64_t num, FloatArray prob, template void RandomEngine::Choice<int64_t, float>(
int64_t* out, bool replace); int64_t num, FloatArray prob, int64_t* out, bool replace);
template void RandomEngine::Choice<int32_t, double>(int32_t num, template void RandomEngine::Choice<int32_t, double>(
FloatArray prob, int32_t num, FloatArray prob, int32_t* out, bool replace);
int32_t* out, bool replace); template void RandomEngine::Choice<int64_t, double>(
template void RandomEngine::Choice<int64_t, double>(int64_t num, int64_t num, FloatArray prob, int64_t* out, bool replace);
FloatArray prob, template void RandomEngine::Choice<int32_t, int8_t>(
int64_t* out, bool replace); int32_t num, FloatArray prob, int32_t* out, bool replace);
template void RandomEngine::Choice<int32_t, int8_t>(int32_t num, FloatArray prob, template void RandomEngine::Choice<int64_t, int8_t>(
int32_t* out, bool replace); int64_t num, FloatArray prob, int64_t* out, bool replace);
template void RandomEngine::Choice<int64_t, int8_t>(int64_t num, FloatArray prob, template void RandomEngine::Choice<int32_t, uint8_t>(
int64_t* out, bool replace); int32_t num, FloatArray prob, int32_t* out, bool replace);
template void RandomEngine::Choice<int32_t, uint8_t>(int32_t num, template void RandomEngine::Choice<int64_t, uint8_t>(
FloatArray prob, int64_t num, FloatArray prob, int64_t* out, bool replace);
int32_t* out, bool replace);
template void RandomEngine::Choice<int64_t, uint8_t>(int64_t num,
FloatArray prob,
int64_t* out, bool replace);
template <typename IdxType> template <typename IdxType>
void RandomEngine::UniformChoice(IdxType num, IdxType population, IdxType* out, void RandomEngine::UniformChoice(
bool replace) { IdxType num, IdxType population, IdxType* out, bool replace) {
CHECK_GE(num, 0) << "The numbers to sample should be non-negative."; CHECK_GE(num, 0) << "The numbers to sample should be non-negative.";
CHECK_GE(population, 0) << "The population size should be non-negative."; CHECK_GE(population, 0) << "The population size should be non-negative.";
if (!replace) if (!replace)
CHECK_LE(num, population) CHECK_LE(num, population)
<< "Cannot take more sample than population when 'replace=false'"; << "Cannot take more sample than population when 'replace=false'";
if (replace) { if (replace) {
for (IdxType i = 0; i < num; ++i) out[i] = RandInt(population); for (IdxType i = 0; i < num; ++i) out[i] = RandInt(population);
} else { } else {
...@@ -112,14 +110,15 @@ void RandomEngine::UniformChoice(IdxType num, IdxType population, IdxType* out, ...@@ -112,14 +110,15 @@ void RandomEngine::UniformChoice(IdxType num, IdxType population, IdxType* out,
} }
} else { } else {
// In this case, `num >= population / 10`. To reduce the computation overhead, // In this case, `num >= population / 10`. To reduce the computation
// we should reduce the number of random number generations. Even though // overhead, we should reduce the number of random number generations.
// reservior algorithm is more memory effficient (it has O(num) memory complexity), // Even though reservior algorithm is more memory effficient (it has
// it generates O(population) random numbers, which is computationally expensive. // O(num) memory complexity), it generates O(population) random numbers,
// This algorithm has memory complexity of O(population) but generates much fewer random // which is computationally expensive. This algorithm has memory
// numbers O(num). In the case of `num >= population/10`, we don't need to worry about // complexity of O(population) but generates much fewer random numbers
// memory complexity because `num` is usually small. So is `population`. Allocating a small // O(num). In the case of `num >= population/10`, we don't need to worry
// piece of memory is very efficient. // about memory complexity because `num` is usually small. So is
// `population`. Allocating a small piece of memory is very efficient.
std::vector<IdxType> seq(population); std::vector<IdxType> seq(population);
for (size_t i = 0; i < seq.size(); i++) seq[i] = i; for (size_t i = 0; i < seq.size(); i++) seq[i] = i;
for (IdxType i = 0; i < num; i++) { for (IdxType i = 0; i < num; i++) {
...@@ -134,23 +133,22 @@ void RandomEngine::UniformChoice(IdxType num, IdxType population, IdxType* out, ...@@ -134,23 +133,22 @@ void RandomEngine::UniformChoice(IdxType num, IdxType population, IdxType* out,
} }
} }
template void RandomEngine::UniformChoice<int32_t>(int32_t num, template void RandomEngine::UniformChoice<int32_t>(
int32_t population, int32_t num, int32_t population, int32_t* out, bool replace);
int32_t* out, bool replace); template void RandomEngine::UniformChoice<int64_t>(
template void RandomEngine::UniformChoice<int64_t>(int64_t num, int64_t num, int64_t population, int64_t* out, bool replace);
int64_t population,
int64_t* out, bool replace);
template <typename IdxType, typename FloatType> template <typename IdxType, typename FloatType>
void RandomEngine::BiasedChoice( void RandomEngine::BiasedChoice(
IdxType num, const IdxType *split, FloatArray bias, IdxType* out, bool replace) { IdxType num, const IdxType* split, FloatArray bias, IdxType* out,
bool replace) {
const int64_t num_tags = bias->shape[0]; const int64_t num_tags = bias->shape[0];
const FloatType *bias_data = static_cast<FloatType *>(bias->data); const FloatType* bias_data = static_cast<FloatType*>(bias->data);
IdxType total_node_num = 0; IdxType total_node_num = 0;
FloatArray prob = NDArray::Empty({num_tags}, bias->dtype, bias->ctx); FloatArray prob = NDArray::Empty({num_tags}, bias->dtype, bias->ctx);
FloatType *prob_data = static_cast<FloatType *>(prob->data); FloatType* prob_data = static_cast<FloatType*>(prob->data);
for (int64_t tag = 0 ; tag < num_tags; ++tag) { for (int64_t tag = 0; tag < num_tags; ++tag) {
int64_t tag_num_nodes = split[tag+1] - split[tag]; int64_t tag_num_nodes = split[tag + 1] - split[tag];
total_node_num += tag_num_nodes; total_node_num += tag_num_nodes;
FloatType tag_bias = bias_data[tag]; FloatType tag_bias = bias_data[tag];
prob_data[tag] = tag_num_nodes * tag_bias; prob_data[tag] = tag_num_nodes * tag_bias;
...@@ -159,19 +157,21 @@ void RandomEngine::BiasedChoice( ...@@ -159,19 +157,21 @@ void RandomEngine::BiasedChoice(
auto sampler = utils::TreeSampler<IdxType, FloatType, true>(this, prob); auto sampler = utils::TreeSampler<IdxType, FloatType, true>(this, prob);
for (IdxType i = 0; i < num; ++i) { for (IdxType i = 0; i < num; ++i) {
const int64_t tag = sampler.Draw(); const int64_t tag = sampler.Draw();
const IdxType tag_num_nodes = split[tag+1] - split[tag]; const IdxType tag_num_nodes = split[tag + 1] - split[tag];
out[i] = RandInt(tag_num_nodes) + split[tag]; out[i] = RandInt(tag_num_nodes) + split[tag];
} }
} else { } else {
utils::TreeSampler<int64_t, FloatType, false> sampler(this, prob, bias_data); utils::TreeSampler<int64_t, FloatType, false> sampler(
this, prob, bias_data);
CHECK_GE(total_node_num, num) CHECK_GE(total_node_num, num)
<< "Cannot take more sample than population when 'replace=false'"; << "Cannot take more sample than population when 'replace=false'";
// we use hash set here. Maybe in the future we should support reservoir algorithm // we use hash set here. Maybe in the future we should support reservoir
// algorithm
std::vector<std::unordered_set<IdxType>> selected(num_tags); std::vector<std::unordered_set<IdxType>> selected(num_tags);
for (IdxType i = 0 ; i < num ; ++i) { for (IdxType i = 0; i < num; ++i) {
const int64_t tag = sampler.Draw(); const int64_t tag = sampler.Draw();
bool inserted = false; bool inserted = false;
const IdxType tag_num_nodes = split[tag+1] - split[tag]; const IdxType tag_num_nodes = split[tag + 1] - split[tag];
IdxType selected_node; IdxType selected_node;
while (!inserted) { while (!inserted) {
CHECK_LT(selected[tag].size(), tag_num_nodes) CHECK_LT(selected[tag].size(), tag_num_nodes)
......
...@@ -6,15 +6,16 @@ ...@@ -6,15 +6,16 @@
#ifndef DGL_RANDOM_CPU_SAMPLE_UTILS_H_ #ifndef DGL_RANDOM_CPU_SAMPLE_UTILS_H_
#define DGL_RANDOM_CPU_SAMPLE_UTILS_H_ #define DGL_RANDOM_CPU_SAMPLE_UTILS_H_
#include <dgl/random.h>
#include <dgl/array.h> #include <dgl/array.h>
#include <dgl/random.h>
#include <algorithm> #include <algorithm>
#include <utility>
#include <queue>
#include <cstdlib>
#include <cmath> #include <cmath>
#include <numeric> #include <cstdlib>
#include <limits> #include <limits>
#include <numeric>
#include <queue>
#include <utility>
#include <vector> #include <vector>
namespace dgl { namespace dgl {
...@@ -32,34 +33,33 @@ class BaseSampler { ...@@ -32,34 +33,33 @@ class BaseSampler {
} }
}; };
// (BarclayII 2022.9.20) Changing the internal data type of probabilities to double since // (BarclayII 2022.9.20) Changing the internal data type of probabilities to
// we are using non-uniform sampling to sample on boolean masks, where False represents // double since we are using non-uniform sampling to sample on boolean masks,
// probability 0. DType could be uint8 in this case, which will give incorrect arithmetic // where False represents probability 0. DType could be uint8 in this case,
// results due to overflowing and/or integer division. // which will give incorrect arithmetic results due to overflowing and/or
// integer division.
/** /**
* AliasSampler is used to sample elements from a given discrete categorical distribution. * AliasSampler is used to sample elements from a given discrete categorical
* Algorithm: Alias Method(https://en.wikipedia.org/wiki/Alias_method) * distribution. Algorithm: Alias
* Sampler building complexity: O(n) * Method(https://en.wikipedia.org/wiki/Alias_method) Sampler building
* Sample w/ replacement complexity: O(1) * complexity: O(n) Sample w/ replacement complexity: O(1) Sample w/o
* Sample w/o replacement complexity: O(log n) * replacement complexity: O(log n)
*/ */
template < template <typename Idx, typename DType, bool replace>
typename Idx, class AliasSampler : public BaseSampler<Idx> {
typename DType,
bool replace>
class AliasSampler: public BaseSampler<Idx> {
private: private:
RandomEngine *re; RandomEngine *re;
Idx N; Idx N;
double accum, taken; // accumulated likelihood double accum, taken; // accumulated likelihood
std::vector<Idx> K; // alias table std::vector<Idx> K; // alias table
std::vector<double> U; // probability table std::vector<double> U; // probability table
FloatArray _prob; // category distribution FloatArray _prob; // category distribution
std::vector<bool> used; // indicate availability, activated when replace=false; std::vector<bool>
std::vector<Idx> id_mapping; // index mapping, activated when replace=false; used; // indicate availability, activated when replace=false;
std::vector<Idx> id_mapping; // index mapping, activated when replace=false;
inline Idx Map(Idx x) const { // Map consecutive indices to unused elements inline Idx Map(Idx x) const { // Map consecutive indices to unused elements
if (replace) if (replace)
return x; return x;
else else
...@@ -72,20 +72,20 @@ class AliasSampler: public BaseSampler<Idx> { ...@@ -72,20 +72,20 @@ class AliasSampler: public BaseSampler<Idx> {
N = 0; N = 0;
accum = 0.; accum = 0.;
taken = 0.; taken = 0.;
if (!replace) if (!replace) id_mapping.clear();
id_mapping.clear();
for (Idx i = 0; i < prob_size; ++i) for (Idx i = 0; i < prob_size; ++i)
if (!used[i]) { if (!used[i]) {
N++; N++;
accum += prob_data[i]; accum += prob_data[i];
if (!replace) if (!replace) id_mapping.push_back(i);
id_mapping.push_back(i);
} }
if (N == 0) LOG(FATAL) << "Cannot take more sample than population when 'replace=false'"; if (N == 0)
LOG(FATAL)
<< "Cannot take more sample than population when 'replace=false'";
K.resize(N); K.resize(N);
U.resize(N); U.resize(N);
double avg = accum / static_cast<double>(N); double avg = accum / static_cast<double>(N);
std::fill(U.begin(), U.end(), avg); // initialize U std::fill(U.begin(), U.end(), avg); // initialize U
std::queue<std::pair<Idx, double> > under, over; std::queue<std::pair<Idx, double> > under, over;
for (Idx i = 0; i < N; ++i) { for (Idx i = 0; i < N; ++i) {
double p = prob_data[Map(i)]; double p = prob_data[Map(i)];
...@@ -93,7 +93,7 @@ class AliasSampler: public BaseSampler<Idx> { ...@@ -93,7 +93,7 @@ class AliasSampler: public BaseSampler<Idx> {
over.push(std::make_pair(i, p)); over.push(std::make_pair(i, p));
else else
under.push(std::make_pair(i, p)); under.push(std::make_pair(i, p));
K[i] = i; // initialize K K[i] = i; // initialize K
} }
while (!under.empty() && !over.empty()) { while (!under.empty() && !over.empty()) {
auto u_pair = under.front(), o_pair = over.front(); auto u_pair = under.front(), o_pair = over.front();
...@@ -113,13 +113,12 @@ class AliasSampler: public BaseSampler<Idx> { ...@@ -113,13 +113,12 @@ class AliasSampler: public BaseSampler<Idx> {
public: public:
void ResetState(FloatArray prob) { void ResetState(FloatArray prob) {
used.resize(prob->shape[0]); used.resize(prob->shape[0]);
if (!replace) if (!replace) _prob = prob;
_prob = prob;
std::fill(used.begin(), used.end(), false); std::fill(used.begin(), used.end(), false);
Reconstruct(prob); Reconstruct(prob);
} }
explicit AliasSampler(RandomEngine* re, FloatArray prob): re(re) { explicit AliasSampler(RandomEngine *re, FloatArray prob) : re(re) {
ResetState(prob); ResetState(prob);
} }
...@@ -128,11 +127,10 @@ class AliasSampler: public BaseSampler<Idx> { ...@@ -128,11 +127,10 @@ class AliasSampler: public BaseSampler<Idx> {
Idx Draw() { Idx Draw() {
if (!replace) { if (!replace) {
const DType *_prob_data = _prob.Ptr<DType>(); const DType *_prob_data = _prob.Ptr<DType>();
if (2 * taken >= accum) if (2 * taken >= accum) Reconstruct(_prob);
Reconstruct(_prob); if (accum <= 0) return -1;
if (accum <= 0) // accum changes after Reconstruct(), so avg should be computed after
return -1; // that.
// accum changes after Reconstruct(), so avg should be computed after that.
double avg = accum / N; double avg = accum / N;
while (true) { while (true) {
double dice = re->Uniform<double>(0, N); double dice = re->Uniform<double>(0, N);
...@@ -151,42 +149,40 @@ class AliasSampler: public BaseSampler<Idx> { ...@@ -151,42 +149,40 @@ class AliasSampler: public BaseSampler<Idx> {
} }
} }
} }
if (accum <= 0) if (accum <= 0) return -1;
return -1;
double avg = accum / N; double avg = accum / N;
double dice = re->Uniform<double>(0, N); double dice = re->Uniform<double>(0, N);
Idx i = static_cast<Idx>(dice); Idx i = static_cast<Idx>(dice);
double p = (dice - i) * avg; double p = (dice - i) * avg;
if (p <= U[i]) if (p <= U[i])
return Map(i); return Map(i);
else else
return Map(K[i]); return Map(K[i]);
} }
}; };
/** /**
* CDFSampler is used to sample elements from a given discrete categorical distribution. * CDFSampler is used to sample elements from a given discrete categorical
* Algorithm: create a cumulative distribution function and conduct binary search for sampling. * distribution. Algorithm: create a cumulative distribution function and
* Reference: https://github.com/numpy/numpy/blob/d37908/numpy/random/mtrand.pyx#L804 * conduct binary search for sampling. Reference:
* https://github.com/numpy/numpy/blob/d37908/numpy/random/mtrand.pyx#L804
* Sampler building complexity: O(n) * Sampler building complexity: O(n)
* Sample w/ and w/o replacement complexity: O(log n) * Sample w/ and w/o replacement complexity: O(log n)
*/ */
template < template <typename Idx, typename DType, bool replace>
typename Idx, class CDFSampler : public BaseSampler<Idx> {
typename DType,
bool replace>
class CDFSampler: public BaseSampler<Idx> {
private: private:
RandomEngine *re; RandomEngine *re;
Idx N; Idx N;
double accum, taken; double accum, taken;
FloatArray _prob; // categorical distribution FloatArray _prob; // categorical distribution
std::vector<double> cdf; // cumulative distribution function std::vector<double> cdf; // cumulative distribution function
std::vector<bool> used; // indicate availability, activated when replace=false; std::vector<bool>
std::vector<Idx> id_mapping; // indicate index mapping, activated when replace=false; used; // indicate availability, activated when replace=false;
std::vector<Idx>
id_mapping; // indicate index mapping, activated when replace=false;
inline Idx Map(Idx x) const { // Map consecutive indices to unused elements inline Idx Map(Idx x) const { // Map consecutive indices to unused elements
if (replace) if (replace)
return x; return x;
else else
...@@ -199,31 +195,30 @@ class CDFSampler: public BaseSampler<Idx> { ...@@ -199,31 +195,30 @@ class CDFSampler: public BaseSampler<Idx> {
N = 0; N = 0;
accum = 0.; accum = 0.;
taken = 0.; taken = 0.;
if (!replace) if (!replace) id_mapping.clear();
id_mapping.clear();
cdf.clear(); cdf.clear();
cdf.push_back(0); cdf.push_back(0);
for (Idx i = 0; i < prob_size; ++i) for (Idx i = 0; i < prob_size; ++i)
if (!used[i]) { if (!used[i]) {
N++; N++;
accum += prob_data[i]; accum += prob_data[i];
if (!replace) if (!replace) id_mapping.push_back(i);
id_mapping.push_back(i);
cdf.push_back(accum); cdf.push_back(accum);
} }
if (N == 0) LOG(FATAL) << "Cannot take more sample than population when 'replace=false'"; if (N == 0)
LOG(FATAL)
<< "Cannot take more sample than population when 'replace=false'";
} }
public: public:
void ResetState(FloatArray prob) { void ResetState(FloatArray prob) {
used.resize(prob->shape[0]); used.resize(prob->shape[0]);
if (!replace) if (!replace) _prob = prob;
_prob = prob;
std::fill(used.begin(), used.end(), false); std::fill(used.begin(), used.end(), false);
Reconstruct(prob); Reconstruct(prob);
} }
explicit CDFSampler(RandomEngine *re, FloatArray prob): re(re) { explicit CDFSampler(RandomEngine *re, FloatArray prob) : re(re) {
ResetState(prob); ResetState(prob);
} }
...@@ -233,13 +228,12 @@ class CDFSampler: public BaseSampler<Idx> { ...@@ -233,13 +228,12 @@ class CDFSampler: public BaseSampler<Idx> {
double eps = std::numeric_limits<double>::min(); double eps = std::numeric_limits<double>::min();
if (!replace) { if (!replace) {
const DType *_prob_data = _prob.Ptr<DType>(); const DType *_prob_data = _prob.Ptr<DType>();
if (2 * taken >= accum) if (2 * taken >= accum) Reconstruct(_prob);
Reconstruct(_prob); if (accum <= 0) return -1;
if (accum <= 0)
return -1;
while (true) { while (true) {
double p = std::max(re->Uniform<double>(0., accum), eps); double p = std::max(re->Uniform<double>(0., accum), eps);
Idx rst = Map(std::lower_bound(cdf.begin(), cdf.end(), p) - cdf.begin() - 1); Idx rst =
Map(std::lower_bound(cdf.begin(), cdf.end(), p) - cdf.begin() - 1);
double cap = static_cast<double>(_prob_data[rst]); double cap = static_cast<double>(_prob_data[rst]);
if (!used[rst]) { if (!used[rst]) {
used[rst] = true; used[rst] = true;
...@@ -248,29 +242,24 @@ class CDFSampler: public BaseSampler<Idx> { ...@@ -248,29 +242,24 @@ class CDFSampler: public BaseSampler<Idx> {
} }
} }
} }
if (accum <= 0) if (accum <= 0) return -1;
return -1;
double p = std::max(re->Uniform<double>(0., accum), eps); double p = std::max(re->Uniform<double>(0., accum), eps);
return Map(std::lower_bound(cdf.begin(), cdf.end(), p) - cdf.begin() - 1); return Map(std::lower_bound(cdf.begin(), cdf.end(), p) - cdf.begin() - 1);
} }
}; };
/** /**
* TreeSampler is used to sample elements from a given discrete categorical distribution. * TreeSampler is used to sample elements from a given discrete categorical
* Algorithm: create a heap that stores accumulated likelihood of its leaf descendents. * distribution. Algorithm: create a heap that stores accumulated likelihood of
* Reference: https://blog.smola.org/post/1016514759 * its leaf descendents. Reference: https://blog.smola.org/post/1016514759
* Sampler building complexity: O(n) * Sampler building complexity: O(n)
* Sample w/ and w/o replacement complexity: O(log n) * Sample w/ and w/o replacement complexity: O(log n)
*/ */
template < template <typename Idx, typename DType, bool replace>
typename Idx, class TreeSampler : public BaseSampler<Idx> {
typename DType,
bool replace>
class TreeSampler: public BaseSampler<Idx> {
private: private:
RandomEngine *re; RandomEngine *re;
std::vector<double> weight; // accumulated likelihood of subtrees. std::vector<double> weight; // accumulated likelihood of subtrees.
int64_t N; int64_t N;
int64_t num_leafs; int64_t num_leafs;
const DType *decrease; const DType *decrease;
...@@ -286,11 +275,11 @@ class TreeSampler: public BaseSampler<Idx> { ...@@ -286,11 +275,11 @@ class TreeSampler: public BaseSampler<Idx> {
weight[i] = weight[i * 2] + weight[i * 2 + 1]; weight[i] = weight[i * 2] + weight[i * 2 + 1];
} }
explicit TreeSampler(RandomEngine *re, FloatArray prob, const DType* decrease = nullptr) explicit TreeSampler(
: re(re), decrease(decrease) { RandomEngine *re, FloatArray prob, const DType *decrease = nullptr)
: re(re), decrease(decrease) {
num_leafs = 1; num_leafs = 1;
while (num_leafs < prob->shape[0]) while (num_leafs < prob->shape[0]) num_leafs *= 2;
num_leafs *= 2;
N = num_leafs * 2; N = num_leafs * 2;
weight.resize(N); weight.resize(N);
ResetState(prob); ResetState(prob);
...@@ -298,18 +287,18 @@ class TreeSampler: public BaseSampler<Idx> { ...@@ -298,18 +287,18 @@ class TreeSampler: public BaseSampler<Idx> {
/* Pick an element from the given distribution and update the tree. /* Pick an element from the given distribution and update the tree.
* *
* The parameter decrease is an array of which the length is the number of categories. * The parameter decrease is an array of which the length is the number of
* Every time an element in the category x is picked, the weight of this category is subtracted * categories. Every time an element in the category x is picked, the weight
* by decrease[x]. It is used to support the case where a category might contains multiple * of this category is subtracted by decrease[x]. It is used to support the
* candidates and decrease[x] is the weight of one candidate of the category x. * case where a category might contains multiple candidates and decrease[x] is
* the weight of one candidate of the category x.
* *
* When decrease == nullptr, it means there is only one candidate in each category and will * When decrease == nullptr, it means there is only one candidate in each
* directly set the weight of the chosen category as 0. * category and will directly set the weight of the chosen category as 0.
* *
*/ */
Idx Draw() { Idx Draw() {
if (weight[1] <= 0) if (weight[1] <= 0) return -1;
return -1;
int64_t cur = 1; int64_t cur = 1;
double p = re->Uniform<double>(0, weight[cur]); double p = re->Uniform<double>(0, weight[cur]);
double accum = 0.; double accum = 0.;
...@@ -319,15 +308,16 @@ class TreeSampler: public BaseSampler<Idx> { ...@@ -319,15 +308,16 @@ class TreeSampler: public BaseSampler<Idx> {
// w_r > 0 can suppress some numerical problems. // w_r > 0 can suppress some numerical problems.
Idx shift = static_cast<Idx>(p > pivot && w_r > 0); Idx shift = static_cast<Idx>(p > pivot && w_r > 0);
cur = cur * 2 + shift; cur = cur * 2 + shift;
if (shift == 1) if (shift == 1) accum = pivot;
accum = pivot;
} }
Idx rst = cur - num_leafs; Idx rst = cur - num_leafs;
if (!replace) { if (!replace) {
while (cur >= 1) { while (cur >= 1) {
if (cur >= num_leafs) if (cur >= num_leafs)
weight[cur] = this->decrease ? weight[cur] =
weight[cur] - static_cast<double>(this->decrease[rst]) : 0.; this->decrease
? weight[cur] - static_cast<double>(this->decrease[rst])
: 0.;
else else
weight[cur] = weight[cur * 2] + weight[cur * 2 + 1]; weight[cur] = weight[cur * 2] + weight[cur * 2 + 1];
cur /= 2; cur /= 2;
......
...@@ -3,14 +3,15 @@ ...@@ -3,14 +3,15 @@
* @file ndarray.cc * @file ndarray.cc
* @brief NDArray container infratructure. * @brief NDArray container infratructure.
*/ */
#include <string.h>
#include <dmlc/logging.h>
#include <dgl/runtime/ndarray.h>
#include <dgl/runtime/c_runtime_api.h> #include <dgl/runtime/c_runtime_api.h>
#include <dgl/runtime/device_api.h> #include <dgl/runtime/device_api.h>
#include <dgl/runtime/ndarray.h>
#include <dgl/runtime/shared_mem.h> #include <dgl/runtime/shared_mem.h>
#include <dgl/zerocopy_serializer.h>
#include <dgl/runtime/tensordispatch.h> #include <dgl/runtime/tensordispatch.h>
#include <dgl/zerocopy_serializer.h>
#include <dmlc/logging.h>
#include <string.h>
#include "runtime_base.h" #include "runtime_base.h"
namespace dgl { namespace dgl {
...@@ -66,16 +67,15 @@ void NDArray::Internal::DefaultDeleter(NDArray::Container* ptr) { ...@@ -66,16 +67,15 @@ void NDArray::Internal::DefaultDeleter(NDArray::Container* ptr) {
ptr->mem = nullptr; ptr->mem = nullptr;
} else if (ptr->dl_tensor.data != nullptr) { } else if (ptr->dl_tensor.data != nullptr) {
// if the array is still pinned before freeing, unpin it. // if the array is still pinned before freeing, unpin it.
if (ptr->pinned_by_dgl_) if (ptr->pinned_by_dgl_) UnpinContainer(ptr);
UnpinContainer(ptr); dgl::runtime::DeviceAPI::Get(ptr->dl_tensor.ctx)
dgl::runtime::DeviceAPI::Get(ptr->dl_tensor.ctx)->FreeDataSpace( ->FreeDataSpace(ptr->dl_tensor.ctx, ptr->dl_tensor.data);
ptr->dl_tensor.ctx, ptr->dl_tensor.data);
} }
delete ptr; delete ptr;
} }
NDArray NDArray::Internal::Create(std::vector<int64_t> shape, NDArray NDArray::Internal::Create(
DGLDataType dtype, DGLContext ctx) { std::vector<int64_t> shape, DGLDataType dtype, DGLContext ctx) {
VerifyDataType(dtype); VerifyDataType(dtype);
// critical zone // critical zone
NDArray::Container* data = new NDArray::Container(); NDArray::Container* data = new NDArray::Container();
...@@ -91,7 +91,7 @@ NDArray NDArray::Internal::Create(std::vector<int64_t> shape, ...@@ -91,7 +91,7 @@ NDArray NDArray::Internal::Create(std::vector<int64_t> shape,
// does not support NULL stride and thus will crash the program). // does not support NULL stride and thus will crash the program).
data->stride_.resize(data->dl_tensor.ndim, 1); data->stride_.resize(data->dl_tensor.ndim, 1);
for (int i = data->dl_tensor.ndim - 2; i >= 0; --i) { for (int i = data->dl_tensor.ndim - 2; i >= 0; --i) {
data->stride_[i] = data->shape_[i+1] * data->stride_[i+1]; data->stride_[i] = data->shape_[i + 1] * data->stride_[i + 1];
} }
data->dl_tensor.strides = dmlc::BeginPtr(data->stride_); data->dl_tensor.strides = dmlc::BeginPtr(data->stride_);
// setup dtype // setup dtype
...@@ -108,13 +108,10 @@ DGLArray* NDArray::Internal::MoveAsDGLArray(NDArray arr) { ...@@ -108,13 +108,10 @@ DGLArray* NDArray::Internal::MoveAsDGLArray(NDArray arr) {
return tensor; return tensor;
} }
size_t NDArray::GetSize() const { size_t NDArray::GetSize() const { return GetDataSize(data_->dl_tensor); }
return GetDataSize(data_->dl_tensor);
}
int64_t NDArray::NumElements() const { int64_t NDArray::NumElements() const {
if (data_->dl_tensor.ndim == 0) if (data_->dl_tensor.ndim == 0) return 0;
return 0;
int64_t size = 1; int64_t size = 1;
for (int i = 0; i < data_->dl_tensor.ndim; ++i) { for (int i = 0; i < data_->dl_tensor.ndim; ++i) {
size *= data_->dl_tensor.shape[i]; size *= data_->dl_tensor.shape[i];
...@@ -124,10 +121,10 @@ int64_t NDArray::NumElements() const { ...@@ -124,10 +121,10 @@ int64_t NDArray::NumElements() const {
bool NDArray::IsContiguous() const { bool NDArray::IsContiguous() const {
CHECK(data_ != nullptr); CHECK(data_ != nullptr);
if (data_->dl_tensor.strides == nullptr) if (data_->dl_tensor.strides == nullptr) return true;
return true;
// See https://github.com/dmlc/dgl/issues/2118 and PyTorch's compute_contiguous() implementation // See https://github.com/dmlc/dgl/issues/2118 and PyTorch's
// compute_contiguous() implementation
int64_t z = 1; int64_t z = 1;
for (int64_t i = data_->dl_tensor.ndim - 1; i >= 0; --i) { for (int64_t i = data_->dl_tensor.ndim - 1; i >= 0; --i) {
if (data_->dl_tensor.shape[i] != 1) { if (data_->dl_tensor.shape[i] != 1) {
...@@ -140,14 +137,12 @@ bool NDArray::IsContiguous() const { ...@@ -140,14 +137,12 @@ bool NDArray::IsContiguous() const {
return true; return true;
} }
NDArray NDArray::CreateView(std::vector<int64_t> shape, NDArray NDArray::CreateView(
DGLDataType dtype, std::vector<int64_t> shape, DGLDataType dtype, int64_t offset) {
int64_t offset) {
CHECK(data_ != nullptr); CHECK(data_ != nullptr);
CHECK(IsContiguous()) << "Can only create view for compact tensor"; CHECK(IsContiguous()) << "Can only create view for compact tensor";
NDArray ret = Internal::Create(shape, dtype, data_->dl_tensor.ctx); NDArray ret = Internal::Create(shape, dtype, data_->dl_tensor.ctx);
ret.data_->dl_tensor.byte_offset = ret.data_->dl_tensor.byte_offset = this->data_->dl_tensor.byte_offset;
this->data_->dl_tensor.byte_offset;
size_t curr_size = GetDataSize(this->data_->dl_tensor); size_t curr_size = GetDataSize(this->data_->dl_tensor);
size_t view_size = GetDataSize(ret.data_->dl_tensor); size_t view_size = GetDataSize(ret.data_->dl_tensor);
CHECK_LE(view_size, curr_size) CHECK_LE(view_size, curr_size)
...@@ -156,14 +151,13 @@ NDArray NDArray::CreateView(std::vector<int64_t> shape, ...@@ -156,14 +151,13 @@ NDArray NDArray::CreateView(std::vector<int64_t> shape,
this->data_->IncRef(); this->data_->IncRef();
ret.data_->manager_ctx = this->data_; ret.data_->manager_ctx = this->data_;
ret.data_->dl_tensor.data = ret.data_->dl_tensor.data =
static_cast<char*>(this->data_->dl_tensor.data) + offset; static_cast<char*>(this->data_->dl_tensor.data) + offset;
return ret; return ret;
} }
NDArray NDArray::EmptyShared(const std::string &name, NDArray NDArray::EmptyShared(
std::vector<int64_t> shape, const std::string& name, std::vector<int64_t> shape, DGLDataType dtype,
DGLDataType dtype, DGLContext ctx, bool is_create) {
DGLContext ctx, bool is_create) {
NDArray ret = Internal::Create(shape, dtype, ctx); NDArray ret = Internal::Create(shape, dtype, ctx);
// setup memory content // setup memory content
size_t size = GetDataSize(ret.data_->dl_tensor); size_t size = GetDataSize(ret.data_->dl_tensor);
...@@ -178,31 +172,28 @@ NDArray NDArray::EmptyShared(const std::string &name, ...@@ -178,31 +172,28 @@ NDArray NDArray::EmptyShared(const std::string &name,
return ret; return ret;
} }
NDArray NDArray::Empty(std::vector<int64_t> shape, NDArray NDArray::Empty(
DGLDataType dtype, std::vector<int64_t> shape, DGLDataType dtype, DGLContext ctx) {
DGLContext ctx) {
NDArray ret = Internal::Create(shape, dtype, ctx); NDArray ret = Internal::Create(shape, dtype, ctx);
// setup memory content // setup memory content
size_t size = GetDataSize(ret.data_->dl_tensor); size_t size = GetDataSize(ret.data_->dl_tensor);
size_t alignment = GetDataAlignment(ret.data_->dl_tensor); size_t alignment = GetDataAlignment(ret.data_->dl_tensor);
if (size > 0) if (size > 0)
ret.data_->dl_tensor.data = ret.data_->dl_tensor.data = DeviceAPI::Get(ret->ctx)->AllocDataSpace(
DeviceAPI::Get(ret->ctx)->AllocDataSpace( ret->ctx, size, alignment, ret->dtype);
ret->ctx, size, alignment, ret->dtype);
return ret; return ret;
} }
void NDArray::CopyFromTo(DGLArray* from, void NDArray::CopyFromTo(DGLArray* from, DGLArray* to) {
DGLArray* to) {
size_t from_size = GetDataSize(*from); size_t from_size = GetDataSize(*from);
size_t to_size = GetDataSize(*to); size_t to_size = GetDataSize(*to);
CHECK_EQ(from_size, to_size) CHECK_EQ(from_size, to_size)
<< "DGLArrayCopyFromTo: The size must exactly match"; << "DGLArrayCopyFromTo: The size must exactly match";
CHECK(from->ctx.device_type == to->ctx.device_type CHECK(
|| from->ctx.device_type == kDGLCPU from->ctx.device_type == to->ctx.device_type ||
|| to->ctx.device_type == kDGLCPU) from->ctx.device_type == kDGLCPU || to->ctx.device_type == kDGLCPU)
<< "Can not copy across different ctx types directly"; << "Can not copy across different ctx types directly";
// Use the context that is *not* a cpu context to get the correct device // Use the context that is *not* a cpu context to get the correct device
// api manager. // api manager.
...@@ -210,16 +201,16 @@ void NDArray::CopyFromTo(DGLArray* from, ...@@ -210,16 +201,16 @@ void NDArray::CopyFromTo(DGLArray* from,
// default: local current cuda stream // default: local current cuda stream
DeviceAPI::Get(ctx)->CopyDataFromTo( DeviceAPI::Get(ctx)->CopyDataFromTo(
from->data, static_cast<size_t>(from->byte_offset), from->data, static_cast<size_t>(from->byte_offset), to->data,
to->data, static_cast<size_t>(to->byte_offset), static_cast<size_t>(to->byte_offset), from_size, from->ctx, to->ctx,
from_size, from->ctx, to->ctx, from->dtype); from->dtype);
} }
void NDArray::PinContainer(NDArray::Container* ptr) { void NDArray::PinContainer(NDArray::Container* ptr) {
if (IsContainerPinned(ptr)) return; if (IsContainerPinned(ptr)) return;
auto* tensor = &(ptr->dl_tensor); auto* tensor = &(ptr->dl_tensor);
CHECK_EQ(tensor->ctx.device_type, kDGLCPU) CHECK_EQ(tensor->ctx.device_type, kDGLCPU)
<< "Only NDArray on CPU can be pinned"; << "Only NDArray on CPU can be pinned";
DeviceAPI::Get(kDGLCUDA)->PinData(tensor->data, GetDataSize(*tensor)); DeviceAPI::Get(kDGLCUDA)->PinData(tensor->data, GetDataSize(*tensor));
ptr->pinned_by_dgl_ = true; ptr->pinned_by_dgl_ = true;
} }
...@@ -229,7 +220,7 @@ void NDArray::UnpinContainer(NDArray::Container* ptr) { ...@@ -229,7 +220,7 @@ void NDArray::UnpinContainer(NDArray::Container* ptr) {
// The tensor may be pinned outside of DGL via a different CUDA API, // The tensor may be pinned outside of DGL via a different CUDA API,
// so we cannot unpin it with cudaHostUnregister. // so we cannot unpin it with cudaHostUnregister.
CHECK(ptr->pinned_by_dgl_ || !container_is_pinned) CHECK(ptr->pinned_by_dgl_ || !container_is_pinned)
<< "Cannot unpin a tensor that is pinned outside of DGL."; << "Cannot unpin a tensor that is pinned outside of DGL.";
// 1. not pinned, do nothing // 1. not pinned, do nothing
if (!container_is_pinned) return; if (!container_is_pinned) return;
// 2. pinned by DGL, unpin it // 2. pinned by DGL, unpin it
...@@ -239,65 +230,61 @@ void NDArray::UnpinContainer(NDArray::Container* ptr) { ...@@ -239,65 +230,61 @@ void NDArray::UnpinContainer(NDArray::Container* ptr) {
void NDArray::RecordStream(DGLArray* tensor, DGLStreamHandle stream) { void NDArray::RecordStream(DGLArray* tensor, DGLStreamHandle stream) {
TensorDispatcher* td = TensorDispatcher::Global(); TensorDispatcher* td = TensorDispatcher::Global();
CHECK(td->IsAvailable()) << "RecordStream only works when TensorAdaptor is available."; CHECK(td->IsAvailable())
<< "RecordStream only works when TensorAdaptor is available.";
CHECK_EQ(tensor->ctx.device_type, kDGLCUDA) CHECK_EQ(tensor->ctx.device_type, kDGLCUDA)
<< "RecordStream only works with GPU tensors."; << "RecordStream only works with GPU tensors.";
td->RecordStream(tensor->data, stream, tensor->ctx.device_id); td->RecordStream(tensor->data, stream, tensor->ctx.device_id);
} }
template<typename T> template <typename T>
NDArray NDArray::FromVector(const std::vector<T>& vec, DGLContext ctx) { NDArray NDArray::FromVector(const std::vector<T>& vec, DGLContext ctx) {
const DGLDataType dtype = DGLDataTypeTraits<T>::dtype; const DGLDataType dtype = DGLDataTypeTraits<T>::dtype;
int64_t size = static_cast<int64_t>(vec.size()); int64_t size = static_cast<int64_t>(vec.size());
NDArray ret = NDArray::Empty({size}, dtype, ctx); NDArray ret = NDArray::Empty({size}, dtype, ctx);
DeviceAPI::Get(ctx)->CopyDataFromTo( DeviceAPI::Get(ctx)->CopyDataFromTo(
vec.data(), vec.data(), 0, static_cast<T*>(ret->data), 0, size * sizeof(T),
0, DGLContext{kDGLCPU, 0}, ctx, dtype);
static_cast<T*>(ret->data),
0,
size * sizeof(T),
DGLContext{kDGLCPU, 0},
ctx,
dtype);
return ret; return ret;
} }
NDArray NDArray::CreateFromRaw(const std::vector<int64_t>& shape, NDArray NDArray::CreateFromRaw(
DGLDataType dtype, DGLContext ctx, void* raw, bool auto_free) { const std::vector<int64_t>& shape, DGLDataType dtype, DGLContext ctx,
void* raw, bool auto_free) {
NDArray ret = Internal::Create(shape, dtype, ctx); NDArray ret = Internal::Create(shape, dtype, ctx);
ret.data_->dl_tensor.data = raw; ret.data_->dl_tensor.data = raw;
if (!auto_free) if (!auto_free) ret.data_->deleter = nullptr;
ret.data_->deleter = nullptr;
return ret; return ret;
} }
// export specializations // export specializations
template NDArray NDArray::FromVector<int32_t>(const std::vector<int32_t>&, DGLContext); template NDArray NDArray::FromVector<int32_t>(
template NDArray NDArray::FromVector<int64_t>(const std::vector<int64_t>&, DGLContext); const std::vector<int32_t>&, DGLContext);
template NDArray NDArray::FromVector<uint32_t>(const std::vector<uint32_t>&, DGLContext); template NDArray NDArray::FromVector<int64_t>(
template NDArray NDArray::FromVector<uint64_t>(const std::vector<uint64_t>&, DGLContext); const std::vector<int64_t>&, DGLContext);
template NDArray NDArray::FromVector<float>(const std::vector<float>&, DGLContext); template NDArray NDArray::FromVector<uint32_t>(
template NDArray NDArray::FromVector<double>(const std::vector<double>&, DGLContext); const std::vector<uint32_t>&, DGLContext);
template NDArray NDArray::FromVector<uint64_t>(
template<typename T> const std::vector<uint64_t>&, DGLContext);
template NDArray NDArray::FromVector<float>(
const std::vector<float>&, DGLContext);
template NDArray NDArray::FromVector<double>(
const std::vector<double>&, DGLContext);
template <typename T>
std::vector<T> NDArray::ToVector() const { std::vector<T> NDArray::ToVector() const {
const DGLDataType dtype = DGLDataTypeTraits<T>::dtype; const DGLDataType dtype = DGLDataTypeTraits<T>::dtype;
CHECK(data_->dl_tensor.ndim == 1) << "ToVector() only supported for 1D arrays"; CHECK(data_->dl_tensor.ndim == 1)
<< "ToVector() only supported for 1D arrays";
CHECK(data_->dl_tensor.dtype == dtype) << "dtype mismatch"; CHECK(data_->dl_tensor.dtype == dtype) << "dtype mismatch";
int64_t size = data_->dl_tensor.shape[0]; int64_t size = data_->dl_tensor.shape[0];
std::vector<T> vec(size); std::vector<T> vec(size);
const DGLContext &ctx = data_->dl_tensor.ctx; const DGLContext& ctx = data_->dl_tensor.ctx;
DeviceAPI::Get(ctx)->CopyDataFromTo( DeviceAPI::Get(ctx)->CopyDataFromTo(
static_cast<T*>(data_->dl_tensor.data), static_cast<T*>(data_->dl_tensor.data), 0, vec.data(), 0,
0, size * sizeof(T), ctx, DGLContext{kDGLCPU, 0}, dtype);
vec.data(),
0,
size * sizeof(T),
ctx,
DGLContext{kDGLCPU, 0},
dtype);
return vec; return vec;
} }
...@@ -313,13 +300,12 @@ std::shared_ptr<SharedMemory> NDArray::GetSharedMem() const { ...@@ -313,13 +300,12 @@ std::shared_ptr<SharedMemory> NDArray::GetSharedMem() const {
} }
bool NDArray::IsContainerPinned(NDArray::Container* ptr) { bool NDArray::IsContainerPinned(NDArray::Container* ptr) {
if (ptr->pinned_by_dgl_) if (ptr->pinned_by_dgl_) return true;
return true;
auto* tensor = &(ptr->dl_tensor); auto* tensor = &(ptr->dl_tensor);
// Can only be pinned if on CPU... // Can only be pinned if on CPU...
if (tensor->ctx.device_type != kDGLCPU) if (tensor->ctx.device_type != kDGLCPU) return false;
return false; // ... and CUDA device API is enabled, and the tensor is indeed in pinned
// ... and CUDA device API is enabled, and the tensor is indeed in pinned memory. // memory.
auto device = DeviceAPI::Get(kDGLCUDA, true); auto device = DeviceAPI::Get(kDGLCUDA, true);
return device && device->IsPinned(tensor->data); return device && device->IsPinned(tensor->data);
} }
...@@ -340,27 +326,20 @@ bool NDArray::Load(dmlc::Stream* strm) { ...@@ -340,27 +326,20 @@ bool NDArray::Load(dmlc::Stream* strm) {
return true; return true;
} }
uint64_t header, reserved; uint64_t header, reserved;
CHECK(strm->Read(&header)) CHECK(strm->Read(&header)) << "Invalid DGLArray file format";
<< "Invalid DGLArray file format"; CHECK(strm->Read(&reserved)) << "Invalid DGLArray file format";
CHECK(strm->Read(&reserved)) CHECK(header == kDGLNDArrayMagic) << "Invalid DGLArray file format";
<< "Invalid DGLArray file format";
CHECK(header == kDGLNDArrayMagic)
<< "Invalid DGLArray file format";
DGLContext ctx; DGLContext ctx;
int ndim; int ndim;
DGLDataType dtype; DGLDataType dtype;
CHECK(strm->Read(&ctx)) CHECK(strm->Read(&ctx)) << "Invalid DGLArray file format";
<< "Invalid DGLArray file format"; CHECK(strm->Read(&ndim)) << "Invalid DGLArray file format";
CHECK(strm->Read(&ndim)) CHECK(strm->Read(&dtype)) << "Invalid DGLArray file format";
<< "Invalid DGLArray file format";
CHECK(strm->Read(&dtype))
<< "Invalid DGLArray file format";
CHECK_EQ(ctx.device_type, kDGLCPU) CHECK_EQ(ctx.device_type, kDGLCPU)
<< "Invalid DGLArray context: can only save as CPU tensor"; << "Invalid DGLArray context: can only save as CPU tensor";
std::vector<int64_t> shape(ndim); std::vector<int64_t> shape(ndim);
if (ndim != 0) { if (ndim != 0) {
CHECK(strm->ReadArray(&shape[0], ndim)) CHECK(strm->ReadArray(&shape[0], ndim)) << "Invalid DGLArray file format";
<< "Invalid DGLArray file format";
} }
NDArray ret = NDArray::Empty(shape, dtype, ctx); NDArray ret = NDArray::Empty(shape, dtype, ctx);
int64_t num_elems = 1; int64_t num_elems = 1;
...@@ -369,11 +348,10 @@ bool NDArray::Load(dmlc::Stream* strm) { ...@@ -369,11 +348,10 @@ bool NDArray::Load(dmlc::Stream* strm) {
num_elems *= ret->shape[i]; num_elems *= ret->shape[i];
} }
int64_t data_byte_size; int64_t data_byte_size;
CHECK(strm->Read(&data_byte_size)) CHECK(strm->Read(&data_byte_size)) << "Invalid DGLArray file format";
<< "Invalid DGLArray file format";
CHECK(data_byte_size == num_elems * elem_bytes) CHECK(data_byte_size == num_elems * elem_bytes)
<< "Invalid DGLArray file format"; << "Invalid DGLArray file format";
if (data_byte_size != 0) { if (data_byte_size != 0) {
// strm->Read will return the total number of elements successfully read. // strm->Read will return the total number of elements successfully read.
// Therefore if data_byte_size is zero, the CHECK below would fail. // Therefore if data_byte_size is zero, the CHECK below would fail.
CHECK(strm->Read(ret->data, data_byte_size)) CHECK(strm->Read(ret->data, data_byte_size))
...@@ -386,20 +364,14 @@ bool NDArray::Load(dmlc::Stream* strm) { ...@@ -386,20 +364,14 @@ bool NDArray::Load(dmlc::Stream* strm) {
return true; return true;
} }
} // namespace runtime } // namespace runtime
} // namespace dgl } // namespace dgl
using namespace dgl::runtime; using namespace dgl::runtime;
int DGLArrayAlloc(const dgl_index_t* shape, int DGLArrayAlloc(
int ndim, const dgl_index_t* shape, int ndim, int dtype_code, int dtype_bits,
int dtype_code, int dtype_lanes, int device_type, int device_id, DGLArrayHandle* out) {
int dtype_bits,
int dtype_lanes,
int device_type,
int device_id,
DGLArrayHandle* out) {
API_BEGIN(); API_BEGIN();
DGLDataType dtype; DGLDataType dtype;
dtype.code = static_cast<uint8_t>(dtype_code); dtype.code = static_cast<uint8_t>(dtype_code);
...@@ -413,22 +385,17 @@ int DGLArrayAlloc(const dgl_index_t* shape, ...@@ -413,22 +385,17 @@ int DGLArrayAlloc(const dgl_index_t* shape,
API_END(); API_END();
} }
int DGLArrayAllocSharedMem(const char *mem_name, int DGLArrayAllocSharedMem(
const dgl_index_t *shape, const char* mem_name, const dgl_index_t* shape, int ndim, int dtype_code,
int ndim, int dtype_bits, int dtype_lanes, bool is_create, DGLArrayHandle* out) {
int dtype_code,
int dtype_bits,
int dtype_lanes,
bool is_create,
DGLArrayHandle* out) {
API_BEGIN(); API_BEGIN();
DGLDataType dtype; DGLDataType dtype;
dtype.code = static_cast<uint8_t>(dtype_code); dtype.code = static_cast<uint8_t>(dtype_code);
dtype.bits = static_cast<uint8_t>(dtype_bits); dtype.bits = static_cast<uint8_t>(dtype_bits);
dtype.lanes = static_cast<uint16_t>(dtype_lanes); dtype.lanes = static_cast<uint16_t>(dtype_lanes);
std::vector<int64_t> shape_vec(shape, shape + ndim); std::vector<int64_t> shape_vec(shape, shape + ndim);
NDArray arr = NDArray::EmptyShared(mem_name, shape_vec, dtype, NDArray arr = NDArray::EmptyShared(
DGLContext{kDGLCPU, 0}, is_create); mem_name, shape_vec, dtype, DGLContext{kDGLCPU, 0}, is_create);
*out = NDArray::Internal::MoveAsDGLArray(arr); *out = NDArray::Internal::MoveAsDGLArray(arr);
API_END(); API_END();
} }
...@@ -439,57 +406,48 @@ int DGLArrayFree(DGLArrayHandle handle) { ...@@ -439,57 +406,48 @@ int DGLArrayFree(DGLArrayHandle handle) {
API_END(); API_END();
} }
int DGLArrayCopyFromTo(DGLArrayHandle from, int DGLArrayCopyFromTo(DGLArrayHandle from, DGLArrayHandle to) {
DGLArrayHandle to) {
API_BEGIN(); API_BEGIN();
NDArray::CopyFromTo(from, to); NDArray::CopyFromTo(from, to);
API_END(); API_END();
} }
int DGLArrayCopyFromBytes(DGLArrayHandle handle, int DGLArrayCopyFromBytes(DGLArrayHandle handle, void* data, size_t nbytes) {
void* data,
size_t nbytes) {
API_BEGIN(); API_BEGIN();
DGLContext cpu_ctx; DGLContext cpu_ctx;
cpu_ctx.device_type = kDGLCPU; cpu_ctx.device_type = kDGLCPU;
cpu_ctx.device_id = 0; cpu_ctx.device_id = 0;
size_t arr_size = GetDataSize(*handle); size_t arr_size = GetDataSize(*handle);
CHECK_EQ(arr_size, nbytes) CHECK_EQ(arr_size, nbytes) << "DGLArrayCopyFromBytes: size mismatch";
<< "DGLArrayCopyFromBytes: size mismatch"; DeviceAPI::Get(handle->ctx)
DeviceAPI::Get(handle->ctx)->CopyDataFromTo( ->CopyDataFromTo(
data, 0, data, 0, handle->data, static_cast<size_t>(handle->byte_offset),
handle->data, static_cast<size_t>(handle->byte_offset), nbytes, cpu_ctx, handle->ctx, handle->dtype);
nbytes, cpu_ctx, handle->ctx, handle->dtype);
API_END(); API_END();
} }
int DGLArrayCopyToBytes(DGLArrayHandle handle, int DGLArrayCopyToBytes(DGLArrayHandle handle, void* data, size_t nbytes) {
void* data,
size_t nbytes) {
API_BEGIN(); API_BEGIN();
DGLContext cpu_ctx; DGLContext cpu_ctx;
cpu_ctx.device_type = kDGLCPU; cpu_ctx.device_type = kDGLCPU;
cpu_ctx.device_id = 0; cpu_ctx.device_id = 0;
size_t arr_size = GetDataSize(*handle); size_t arr_size = GetDataSize(*handle);
CHECK_EQ(arr_size, nbytes) CHECK_EQ(arr_size, nbytes) << "DGLArrayCopyToBytes: size mismatch";
<< "DGLArrayCopyToBytes: size mismatch"; DeviceAPI::Get(handle->ctx)
DeviceAPI::Get(handle->ctx)->CopyDataFromTo( ->CopyDataFromTo(
handle->data, static_cast<size_t>(handle->byte_offset), handle->data, static_cast<size_t>(handle->byte_offset), data, 0,
data, 0, nbytes, handle->ctx, cpu_ctx, handle->dtype);
nbytes, handle->ctx, cpu_ctx, handle->dtype);
API_END(); API_END();
} }
int DGLArrayPinData(DGLArrayHandle handle, int DGLArrayPinData(DGLArrayHandle handle, DGLContext ctx) {
DGLContext ctx) {
API_BEGIN(); API_BEGIN();
auto* nd_container = reinterpret_cast<NDArray::Container*>(handle); auto* nd_container = reinterpret_cast<NDArray::Container*>(handle);
NDArray::PinContainer(nd_container); NDArray::PinContainer(nd_container);
API_END(); API_END();
} }
int DGLArrayUnpinData(DGLArrayHandle handle, int DGLArrayUnpinData(DGLArrayHandle handle, DGLContext ctx) {
DGLContext ctx) {
API_BEGIN(); API_BEGIN();
auto* nd_container = reinterpret_cast<NDArray::Container*>(handle); auto* nd_container = reinterpret_cast<NDArray::Container*>(handle);
NDArray::UnpinContainer(nd_container); NDArray::UnpinContainer(nd_container);
......
#include <gtest/gtest.h>
#include <dgl/array.h> #include <dgl/array.h>
#include <tuple> #include <gtest/gtest.h>
#include <set> #include <set>
#include <tuple>
#include "./common.h" #include "./common.h"
using namespace dgl; using namespace dgl;
...@@ -62,7 +64,7 @@ std::set<ETuple<Idx>> ToEdgeSet(COOMatrix mat) { ...@@ -62,7 +64,7 @@ std::set<ETuple<Idx>> ToEdgeSet(COOMatrix mat) {
Idx* col = static_cast<Idx*>(mat.col->data); Idx* col = static_cast<Idx*>(mat.col->data);
Idx* data = static_cast<Idx*>(mat.data->data); Idx* data = static_cast<Idx*>(mat.data->data);
for (int64_t i = 0; i < mat.row->shape[0]; ++i) { for (int64_t i = 0; i < mat.row->shape[0]; ++i) {
//std::cout << row[i] << " " << col[i] << " " << data[i] << std::endl; // std::cout << row[i] << " " << col[i] << " " << data[i] << std::endl;
eset.emplace(row[i], col[i], data[i]); eset.emplace(row[i], col[i], data[i]);
} }
return eset; return eset;
...@@ -122,7 +124,8 @@ COOMatrix COO(bool has_data) { ...@@ -122,7 +124,8 @@ COOMatrix COO(bool has_data) {
template <typename Idx> template <typename Idx>
std::pair<CSRMatrix, std::vector<int64_t>> CSREtypes(bool has_data) { std::pair<CSRMatrix, std::vector<int64_t>> CSREtypes(bool has_data) {
IdArray indptr = NDArray::FromVector(std::vector<Idx>({0, 4, 5, 5, 7})); IdArray indptr = NDArray::FromVector(std::vector<Idx>({0, 4, 5, 5, 7}));
IdArray indices = NDArray::FromVector(std::vector<Idx>({0, 1, 2, 3, 1, 3, 2})); IdArray indices =
NDArray::FromVector(std::vector<Idx>({0, 1, 2, 3, 1, 3, 2}));
IdArray data = NDArray::FromVector(std::vector<Idx>({0, 1, 4, 6, 2, 3, 5})); IdArray data = NDArray::FromVector(std::vector<Idx>({0, 1, 4, 6, 2, 3, 5}));
auto eid2etype_offsets = std::vector<int64_t>({0, 4, 5, 6, 7}); auto eid2etype_offsets = std::vector<int64_t>({0, 4, 5, 6, 7});
if (has_data) if (has_data)
...@@ -146,8 +149,8 @@ std::pair<COOMatrix, std::vector<int64_t>> COOEtypes(bool has_data) { ...@@ -146,8 +149,8 @@ std::pair<COOMatrix, std::vector<int64_t>> COOEtypes(bool has_data) {
template <typename Idx, typename FloatType> template <typename Idx, typename FloatType>
void _TestCSRSampling(bool has_data) { void _TestCSRSampling(bool has_data) {
auto mat = CSR<Idx>(has_data); auto mat = CSR<Idx>(has_data);
FloatArray prob = NDArray::FromVector( FloatArray prob =
std::vector<FloatType>({.5, .5, .5, .5, .5})); NDArray::FromVector(std::vector<FloatType>({.5, .5, .5, .5, .5}));
IdArray rows = NDArray::FromVector(std::vector<Idx>({0, 3})); IdArray rows = NDArray::FromVector(std::vector<Idx>({0, 3}));
for (int k = 0; k < 10; ++k) { for (int k = 0; k < 10; ++k) {
auto rst = CSRRowWiseSampling(mat, rows, 2, prob, true); auto rst = CSRRowWiseSampling(mat, rows, 2, prob, true);
...@@ -170,8 +173,7 @@ void _TestCSRSampling(bool has_data) { ...@@ -170,8 +173,7 @@ void _TestCSRSampling(bool has_data) {
ASSERT_TRUE(eset.count(std::make_tuple(3, 3, 4))); ASSERT_TRUE(eset.count(std::make_tuple(3, 3, 4)));
} }
} }
prob = NDArray::FromVector( prob = NDArray::FromVector(std::vector<FloatType>({.0, .5, .5, .0, .5}));
std::vector<FloatType>({.0, .5, .5, .0, .5}));
for (int k = 0; k < 100; ++k) { for (int k = 0; k < 100; ++k) {
auto rst = CSRRowWiseSampling(mat, rows, 2, prob, true); auto rst = CSRRowWiseSampling(mat, rows, 2, prob, true);
CheckSampledResult<Idx>(rst, rows, has_data); CheckSampledResult<Idx>(rst, rows, has_data);
...@@ -240,18 +242,19 @@ void _TestCSRPerEtypeSampling(bool has_data) { ...@@ -240,18 +242,19 @@ void _TestCSRPerEtypeSampling(bool has_data) {
auto mat = pair.first; auto mat = pair.first;
auto eid2etype_offset = pair.second; auto eid2etype_offset = pair.second;
std::vector<FloatArray> prob = { std::vector<FloatArray> prob = {
NDArray::FromVector(std::vector<FloatType>({.5, .5, .5, .5})), NDArray::FromVector(std::vector<FloatType>({.5, .5, .5, .5})),
NDArray::FromVector(std::vector<FloatType>({.5})), NDArray::FromVector(std::vector<FloatType>({.5})),
NDArray::FromVector(std::vector<FloatType>({.5})), NDArray::FromVector(std::vector<FloatType>({.5})),
NDArray::FromVector(std::vector<FloatType>({.5})) NDArray::FromVector(std::vector<FloatType>({.5}))};
};
IdArray rows = NDArray::FromVector(std::vector<Idx>({0, 3})); IdArray rows = NDArray::FromVector(std::vector<Idx>({0, 3}));
for (int k = 0; k < 10; ++k) { for (int k = 0; k < 10; ++k) {
auto rst = CSRRowWisePerEtypeSampling(mat, rows, eid2etype_offset, {2, 2, 2, 2}, prob, true); auto rst = CSRRowWisePerEtypeSampling(
mat, rows, eid2etype_offset, {2, 2, 2, 2}, prob, true);
CheckSampledPerEtypeResult<Idx>(rst, rows, has_data); CheckSampledPerEtypeResult<Idx>(rst, rows, has_data);
} }
for (int k = 0; k < 10; ++k) { for (int k = 0; k < 10; ++k) {
auto rst = CSRRowWisePerEtypeSampling(mat, rows, eid2etype_offset, {2, 2, 2, 2}, prob, false); auto rst = CSRRowWisePerEtypeSampling(
mat, rows, eid2etype_offset, {2, 2, 2, 2}, prob, false);
CheckSampledPerEtypeResult<Idx>(rst, rows, has_data); CheckSampledPerEtypeResult<Idx>(rst, rows, has_data);
auto eset = ToEdgeSet<Idx>(rst); auto eset = ToEdgeSet<Idx>(rst);
if (has_data) { if (has_data) {
...@@ -294,13 +297,13 @@ void _TestCSRPerEtypeSampling(bool has_data) { ...@@ -294,13 +297,13 @@ void _TestCSRPerEtypeSampling(bool has_data) {
} }
prob = { prob = {
NDArray::FromVector(std::vector<FloatType>({.0, .5, .0, .0})), NDArray::FromVector(std::vector<FloatType>({.0, .5, .0, .0})),
NDArray::FromVector(std::vector<FloatType>({.5})), NDArray::FromVector(std::vector<FloatType>({.5})),
NDArray::FromVector(std::vector<FloatType>({.5})), NDArray::FromVector(std::vector<FloatType>({.5})),
NDArray::FromVector(std::vector<FloatType>({.5})) NDArray::FromVector(std::vector<FloatType>({.5}))};
};
for (int k = 0; k < 10; ++k) { for (int k = 0; k < 10; ++k) {
auto rst = CSRRowWisePerEtypeSampling(mat, rows, eid2etype_offset, {2, 2, 2, 2}, prob, true); auto rst = CSRRowWisePerEtypeSampling(
mat, rows, eid2etype_offset, {2, 2, 2, 2}, prob, true);
CheckSampledPerEtypeResult<Idx>(rst, rows, has_data); CheckSampledPerEtypeResult<Idx>(rst, rows, has_data);
auto eset = ToEdgeSet<Idx>(rst); auto eset = ToEdgeSet<Idx>(rst);
if (has_data) { if (has_data) {
...@@ -319,18 +322,19 @@ void _TestCSRPerEtypeSamplingSorted() { ...@@ -319,18 +322,19 @@ void _TestCSRPerEtypeSamplingSorted() {
auto mat = pair.first; auto mat = pair.first;
auto eid2etype_offset = pair.second; auto eid2etype_offset = pair.second;
std::vector<FloatArray> prob = { std::vector<FloatArray> prob = {
NDArray::FromVector(std::vector<FloatType>({.5, .5, .5, .5})), NDArray::FromVector(std::vector<FloatType>({.5, .5, .5, .5})),
NDArray::FromVector(std::vector<FloatType>({.5})), NDArray::FromVector(std::vector<FloatType>({.5})),
NDArray::FromVector(std::vector<FloatType>({.5})), NDArray::FromVector(std::vector<FloatType>({.5})),
NDArray::FromVector(std::vector<FloatType>({.5})) NDArray::FromVector(std::vector<FloatType>({.5}))};
};
IdArray rows = NDArray::FromVector(std::vector<Idx>({0, 3})); IdArray rows = NDArray::FromVector(std::vector<Idx>({0, 3}));
for (int k = 0; k < 10; ++k) { for (int k = 0; k < 10; ++k) {
auto rst = CSRRowWisePerEtypeSampling(mat, rows, eid2etype_offset, {2, 2, 2, 2}, prob, true, true); auto rst = CSRRowWisePerEtypeSampling(
mat, rows, eid2etype_offset, {2, 2, 2, 2}, prob, true, true);
CheckSampledPerEtypeResult<Idx>(rst, rows, true); CheckSampledPerEtypeResult<Idx>(rst, rows, true);
} }
for (int k = 0; k < 10; ++k) { for (int k = 0; k < 10; ++k) {
auto rst = CSRRowWisePerEtypeSampling(mat, rows, eid2etype_offset, {2, 2, 2, 2}, prob, false, true); auto rst = CSRRowWisePerEtypeSampling(
mat, rows, eid2etype_offset, {2, 2, 2, 2}, prob, false, true);
CheckSampledPerEtypeResult<Idx>(rst, rows, true); CheckSampledPerEtypeResult<Idx>(rst, rows, true);
auto eset = ToEdgeSet<Idx>(rst); auto eset = ToEdgeSet<Idx>(rst);
int counts = 0; int counts = 0;
...@@ -355,13 +359,13 @@ void _TestCSRPerEtypeSamplingSorted() { ...@@ -355,13 +359,13 @@ void _TestCSRPerEtypeSamplingSorted() {
} }
prob = { prob = {
NDArray::FromVector(std::vector<FloatType>({.0, .5, .0, .0})), NDArray::FromVector(std::vector<FloatType>({.0, .5, .0, .0})),
NDArray::FromVector(std::vector<FloatType>({.5})), NDArray::FromVector(std::vector<FloatType>({.5})),
NDArray::FromVector(std::vector<FloatType>({.5})), NDArray::FromVector(std::vector<FloatType>({.5})),
NDArray::FromVector(std::vector<FloatType>({.5})) NDArray::FromVector(std::vector<FloatType>({.5}))};
};
for (int k = 0; k < 10; ++k) { for (int k = 0; k < 10; ++k) {
auto rst = CSRRowWisePerEtypeSampling(mat, rows, eid2etype_offset, {2, 2, 2, 2}, prob, true, true); auto rst = CSRRowWisePerEtypeSampling(
mat, rows, eid2etype_offset, {2, 2, 2, 2}, prob, true, true);
CheckSampledPerEtypeResult<Idx>(rst, rows, true); CheckSampledPerEtypeResult<Idx>(rst, rows, true);
auto eset = ToEdgeSet<Idx>(rst); auto eset = ToEdgeSet<Idx>(rst);
ASSERT_FALSE(eset.count(std::make_tuple(0, 0, 0))); ASSERT_FALSE(eset.count(std::make_tuple(0, 0, 0)));
...@@ -389,18 +393,17 @@ void _TestCSRPerEtypeSamplingUniform(bool has_data) { ...@@ -389,18 +393,17 @@ void _TestCSRPerEtypeSamplingUniform(bool has_data) {
auto mat = pair.first; auto mat = pair.first;
auto eid2etype_offset = pair.second; auto eid2etype_offset = pair.second;
std::vector<FloatArray> prob = { std::vector<FloatArray> prob = {
aten::NullArray(), aten::NullArray(), aten::NullArray(), aten::NullArray(),
aten::NullArray(), aten::NullArray()};
aten::NullArray(),
aten::NullArray()
};
IdArray rows = NDArray::FromVector(std::vector<Idx>({0, 3})); IdArray rows = NDArray::FromVector(std::vector<Idx>({0, 3}));
for (int k = 0; k < 10; ++k) { for (int k = 0; k < 10; ++k) {
auto rst = CSRRowWisePerEtypeSampling(mat, rows, eid2etype_offset, {2, 2, 2, 2}, prob, true); auto rst = CSRRowWisePerEtypeSampling(
mat, rows, eid2etype_offset, {2, 2, 2, 2}, prob, true);
CheckSampledPerEtypeResult<Idx>(rst, rows, has_data); CheckSampledPerEtypeResult<Idx>(rst, rows, has_data);
} }
for (int k = 0; k < 10; ++k) { for (int k = 0; k < 10; ++k) {
auto rst = CSRRowWisePerEtypeSampling(mat, rows, eid2etype_offset, {2, 2, 2, 2}, prob, false); auto rst = CSRRowWisePerEtypeSampling(
mat, rows, eid2etype_offset, {2, 2, 2, 2}, prob, false);
CheckSampledPerEtypeResult<Idx>(rst, rows, has_data); CheckSampledPerEtypeResult<Idx>(rst, rows, has_data);
auto eset = ToEdgeSet<Idx>(rst); auto eset = ToEdgeSet<Idx>(rst);
if (has_data) { if (has_data) {
...@@ -449,18 +452,17 @@ void _TestCSRPerEtypeSamplingUniformSorted() { ...@@ -449,18 +452,17 @@ void _TestCSRPerEtypeSamplingUniformSorted() {
auto mat = pair.first; auto mat = pair.first;
auto eid2etype_offset = pair.second; auto eid2etype_offset = pair.second;
std::vector<FloatArray> prob = { std::vector<FloatArray> prob = {
aten::NullArray(), aten::NullArray(), aten::NullArray(), aten::NullArray(),
aten::NullArray(), aten::NullArray()};
aten::NullArray(),
aten::NullArray()
};
IdArray rows = NDArray::FromVector(std::vector<Idx>({0, 3})); IdArray rows = NDArray::FromVector(std::vector<Idx>({0, 3}));
for (int k = 0; k < 10; ++k) { for (int k = 0; k < 10; ++k) {
auto rst = CSRRowWisePerEtypeSampling(mat, rows, eid2etype_offset, {2, 2, 2, 2}, prob, true, true); auto rst = CSRRowWisePerEtypeSampling(
mat, rows, eid2etype_offset, {2, 2, 2, 2}, prob, true, true);
CheckSampledPerEtypeResult<Idx>(rst, rows, true); CheckSampledPerEtypeResult<Idx>(rst, rows, true);
} }
for (int k = 0; k < 10; ++k) { for (int k = 0; k < 10; ++k) {
auto rst = CSRRowWisePerEtypeSampling(mat, rows, eid2etype_offset, {2, 2, 2, 2}, prob, false, true); auto rst = CSRRowWisePerEtypeSampling(
mat, rows, eid2etype_offset, {2, 2, 2, 2}, prob, false, true);
CheckSampledPerEtypeResult<Idx>(rst, rows, true); CheckSampledPerEtypeResult<Idx>(rst, rows, true);
auto eset = ToEdgeSet<Idx>(rst); auto eset = ToEdgeSet<Idx>(rst);
int counts = 0; int counts = 0;
...@@ -500,12 +502,11 @@ TEST(RowwiseTest, TestCSRPerEtypeSamplingUniform) { ...@@ -500,12 +502,11 @@ TEST(RowwiseTest, TestCSRPerEtypeSamplingUniform) {
_TestCSRPerEtypeSamplingUniformSorted<int64_t, double>(); _TestCSRPerEtypeSamplingUniformSorted<int64_t, double>();
} }
template <typename Idx, typename FloatType> template <typename Idx, typename FloatType>
void _TestCOOSampling(bool has_data) { void _TestCOOSampling(bool has_data) {
auto mat = COO<Idx>(has_data); auto mat = COO<Idx>(has_data);
FloatArray prob = NDArray::FromVector( FloatArray prob =
std::vector<FloatType>({.5, .5, .5, .5, .5})); NDArray::FromVector(std::vector<FloatType>({.5, .5, .5, .5, .5}));
IdArray rows = NDArray::FromVector(std::vector<Idx>({0, 3})); IdArray rows = NDArray::FromVector(std::vector<Idx>({0, 3}));
for (int k = 0; k < 10; ++k) { for (int k = 0; k < 10; ++k) {
auto rst = COORowWiseSampling(mat, rows, 2, prob, true); auto rst = COORowWiseSampling(mat, rows, 2, prob, true);
...@@ -528,8 +529,7 @@ void _TestCOOSampling(bool has_data) { ...@@ -528,8 +529,7 @@ void _TestCOOSampling(bool has_data) {
ASSERT_TRUE(eset.count(std::make_tuple(3, 3, 4))); ASSERT_TRUE(eset.count(std::make_tuple(3, 3, 4)));
} }
} }
prob = NDArray::FromVector( prob = NDArray::FromVector(std::vector<FloatType>({.0, .5, .5, .0, .5}));
std::vector<FloatType>({.0, .5, .5, .0, .5}));
for (int k = 0; k < 100; ++k) { for (int k = 0; k < 100; ++k) {
auto rst = COORowWiseSampling(mat, rows, 2, prob, true); auto rst = COORowWiseSampling(mat, rows, 2, prob, true);
CheckSampledResult<Idx>(rst, rows, has_data); CheckSampledResult<Idx>(rst, rows, has_data);
...@@ -601,18 +601,19 @@ void _TestCOOPerEtypeSampling(bool has_data) { ...@@ -601,18 +601,19 @@ void _TestCOOPerEtypeSampling(bool has_data) {
auto mat = pair.first; auto mat = pair.first;
auto eid2etype_offset = pair.second; auto eid2etype_offset = pair.second;
std::vector<FloatArray> prob = { std::vector<FloatArray> prob = {
NDArray::FromVector(std::vector<FloatType>({.5, .5, .5, .5})), NDArray::FromVector(std::vector<FloatType>({.5, .5, .5, .5})),
NDArray::FromVector(std::vector<FloatType>({.5})), NDArray::FromVector(std::vector<FloatType>({.5})),
NDArray::FromVector(std::vector<FloatType>({.5})), NDArray::FromVector(std::vector<FloatType>({.5})),
NDArray::FromVector(std::vector<FloatType>({.5})) NDArray::FromVector(std::vector<FloatType>({.5}))};
};
IdArray rows = NDArray::FromVector(std::vector<Idx>({0, 3})); IdArray rows = NDArray::FromVector(std::vector<Idx>({0, 3}));
for (int k = 0; k < 10; ++k) { for (int k = 0; k < 10; ++k) {
auto rst = COORowWisePerEtypeSampling(mat, rows, eid2etype_offset, {2, 2, 2, 2}, prob, true); auto rst = COORowWisePerEtypeSampling(
mat, rows, eid2etype_offset, {2, 2, 2, 2}, prob, true);
CheckSampledPerEtypeResult<Idx>(rst, rows, has_data); CheckSampledPerEtypeResult<Idx>(rst, rows, has_data);
} }
for (int k = 0; k < 10; ++k) { for (int k = 0; k < 10; ++k) {
auto rst = COORowWisePerEtypeSampling(mat, rows, eid2etype_offset, {2, 2, 2, 2}, prob, false); auto rst = COORowWisePerEtypeSampling(
mat, rows, eid2etype_offset, {2, 2, 2, 2}, prob, false);
CheckSampledPerEtypeResult<Idx>(rst, rows, has_data); CheckSampledPerEtypeResult<Idx>(rst, rows, has_data);
auto eset = ToEdgeSet<Idx>(rst); auto eset = ToEdgeSet<Idx>(rst);
if (has_data) { if (has_data) {
...@@ -655,13 +656,13 @@ void _TestCOOPerEtypeSampling(bool has_data) { ...@@ -655,13 +656,13 @@ void _TestCOOPerEtypeSampling(bool has_data) {
} }
prob = { prob = {
NDArray::FromVector(std::vector<FloatType>({.0, .5, .0, .0})), NDArray::FromVector(std::vector<FloatType>({.0, .5, .0, .0})),
NDArray::FromVector(std::vector<FloatType>({.5})), NDArray::FromVector(std::vector<FloatType>({.5})),
NDArray::FromVector(std::vector<FloatType>({.5})), NDArray::FromVector(std::vector<FloatType>({.5})),
NDArray::FromVector(std::vector<FloatType>({.5})) NDArray::FromVector(std::vector<FloatType>({.5}))};
};
for (int k = 0; k < 10; ++k) { for (int k = 0; k < 10; ++k) {
auto rst = COORowWisePerEtypeSampling(mat, rows, eid2etype_offset, {2, 2, 2, 2}, prob, true); auto rst = COORowWisePerEtypeSampling(
mat, rows, eid2etype_offset, {2, 2, 2, 2}, prob, true);
CheckSampledPerEtypeResult<Idx>(rst, rows, has_data); CheckSampledPerEtypeResult<Idx>(rst, rows, has_data);
auto eset = ToEdgeSet<Idx>(rst); auto eset = ToEdgeSet<Idx>(rst);
if (has_data) { if (has_data) {
...@@ -691,18 +692,17 @@ void _TestCOOPerEtypeSamplingUniform(bool has_data) { ...@@ -691,18 +692,17 @@ void _TestCOOPerEtypeSamplingUniform(bool has_data) {
auto mat = pair.first; auto mat = pair.first;
auto eid2etype_offset = pair.second; auto eid2etype_offset = pair.second;
std::vector<FloatArray> prob = { std::vector<FloatArray> prob = {
aten::NullArray(), aten::NullArray(), aten::NullArray(), aten::NullArray(),
aten::NullArray(), aten::NullArray()};
aten::NullArray(),
aten::NullArray()
};
IdArray rows = NDArray::FromVector(std::vector<Idx>({0, 3})); IdArray rows = NDArray::FromVector(std::vector<Idx>({0, 3}));
for (int k = 0; k < 10; ++k) { for (int k = 0; k < 10; ++k) {
auto rst = COORowWisePerEtypeSampling(mat, rows, eid2etype_offset, {2, 2, 2, 2}, prob, true); auto rst = COORowWisePerEtypeSampling(
mat, rows, eid2etype_offset, {2, 2, 2, 2}, prob, true);
CheckSampledPerEtypeResult<Idx>(rst, rows, has_data); CheckSampledPerEtypeResult<Idx>(rst, rows, has_data);
} }
for (int k = 0; k < 10; ++k) { for (int k = 0; k < 10; ++k) {
auto rst = COORowWisePerEtypeSampling(mat, rows, eid2etype_offset, {2, 2, 2, 2}, prob, false); auto rst = COORowWisePerEtypeSampling(
mat, rows, eid2etype_offset, {2, 2, 2, 2}, prob, false);
CheckSampledPerEtypeResult<Idx>(rst, rows, has_data); CheckSampledPerEtypeResult<Idx>(rst, rows, has_data);
auto eset = ToEdgeSet<Idx>(rst); auto eset = ToEdgeSet<Idx>(rst);
if (has_data) { if (has_data) {
...@@ -759,35 +759,35 @@ TEST(RowwiseTest, TestCOOPerEtypeSamplingUniform) { ...@@ -759,35 +759,35 @@ TEST(RowwiseTest, TestCOOPerEtypeSamplingUniform) {
template <typename Idx, typename FloatType> template <typename Idx, typename FloatType>
void _TestCSRTopk(bool has_data) { void _TestCSRTopk(bool has_data) {
auto mat = CSR<Idx>(has_data); auto mat = CSR<Idx>(has_data);
FloatArray weight = NDArray::FromVector( FloatArray weight =
std::vector<FloatType>({.1f, .0f, -.1f, .2f, .5f})); NDArray::FromVector(std::vector<FloatType>({.1f, .0f, -.1f, .2f, .5f}));
// -.1, .2, .1, .0, .5 // -.1, .2, .1, .0, .5
IdArray rows = NDArray::FromVector(std::vector<Idx>({0, 3})); IdArray rows = NDArray::FromVector(std::vector<Idx>({0, 3}));
{ {
auto rst = CSRRowWiseTopk(mat, rows, 1, weight, true); auto rst = CSRRowWiseTopk(mat, rows, 1, weight, true);
auto eset = ToEdgeSet<Idx>(rst); auto eset = ToEdgeSet<Idx>(rst);
ASSERT_EQ(eset.size(), 2); ASSERT_EQ(eset.size(), 2);
if (has_data) { if (has_data) {
ASSERT_TRUE(eset.count(std::make_tuple(0, 0, 2))); ASSERT_TRUE(eset.count(std::make_tuple(0, 0, 2)));
ASSERT_TRUE(eset.count(std::make_tuple(3, 2, 1))); ASSERT_TRUE(eset.count(std::make_tuple(3, 2, 1)));
} else { } else {
ASSERT_TRUE(eset.count(std::make_tuple(0, 1, 1))); ASSERT_TRUE(eset.count(std::make_tuple(0, 1, 1)));
ASSERT_TRUE(eset.count(std::make_tuple(3, 2, 3))); ASSERT_TRUE(eset.count(std::make_tuple(3, 2, 3)));
} }
} }
{ {
auto rst = CSRRowWiseTopk(mat, rows, 1, weight, false); auto rst = CSRRowWiseTopk(mat, rows, 1, weight, false);
auto eset = ToEdgeSet<Idx>(rst); auto eset = ToEdgeSet<Idx>(rst);
ASSERT_EQ(eset.size(), 2); ASSERT_EQ(eset.size(), 2);
if (has_data) { if (has_data) {
ASSERT_TRUE(eset.count(std::make_tuple(0, 1, 3))); ASSERT_TRUE(eset.count(std::make_tuple(0, 1, 3)));
ASSERT_TRUE(eset.count(std::make_tuple(3, 3, 4))); ASSERT_TRUE(eset.count(std::make_tuple(3, 3, 4)));
} else { } else {
ASSERT_TRUE(eset.count(std::make_tuple(0, 0, 0))); ASSERT_TRUE(eset.count(std::make_tuple(0, 0, 0)));
ASSERT_TRUE(eset.count(std::make_tuple(3, 3, 4))); ASSERT_TRUE(eset.count(std::make_tuple(3, 3, 4)));
} }
} }
} }
...@@ -802,39 +802,38 @@ TEST(RowwiseTest, TestCSRTopk) { ...@@ -802,39 +802,38 @@ TEST(RowwiseTest, TestCSRTopk) {
_TestCSRTopk<int64_t, double>(false); _TestCSRTopk<int64_t, double>(false);
} }
template <typename Idx, typename FloatType> template <typename Idx, typename FloatType>
void _TestCOOTopk(bool has_data) { void _TestCOOTopk(bool has_data) {
auto mat = COO<Idx>(has_data); auto mat = COO<Idx>(has_data);
FloatArray weight = NDArray::FromVector( FloatArray weight =
std::vector<FloatType>({.1f, .0f, -.1f, .2f, .5f})); NDArray::FromVector(std::vector<FloatType>({.1f, .0f, -.1f, .2f, .5f}));
// -.1, .2, .1, .0, .5 // -.1, .2, .1, .0, .5
IdArray rows = NDArray::FromVector(std::vector<Idx>({0, 3})); IdArray rows = NDArray::FromVector(std::vector<Idx>({0, 3}));
{ {
auto rst = COORowWiseTopk(mat, rows, 1, weight, true); auto rst = COORowWiseTopk(mat, rows, 1, weight, true);
auto eset = ToEdgeSet<Idx>(rst); auto eset = ToEdgeSet<Idx>(rst);
ASSERT_EQ(eset.size(), 2); ASSERT_EQ(eset.size(), 2);
if (has_data) { if (has_data) {
ASSERT_TRUE(eset.count(std::make_tuple(0, 0, 2))); ASSERT_TRUE(eset.count(std::make_tuple(0, 0, 2)));
ASSERT_TRUE(eset.count(std::make_tuple(3, 2, 1))); ASSERT_TRUE(eset.count(std::make_tuple(3, 2, 1)));
} else { } else {
ASSERT_TRUE(eset.count(std::make_tuple(0, 1, 1))); ASSERT_TRUE(eset.count(std::make_tuple(0, 1, 1)));
ASSERT_TRUE(eset.count(std::make_tuple(3, 2, 3))); ASSERT_TRUE(eset.count(std::make_tuple(3, 2, 3)));
} }
} }
{ {
auto rst = COORowWiseTopk(mat, rows, 1, weight, false); auto rst = COORowWiseTopk(mat, rows, 1, weight, false);
auto eset = ToEdgeSet<Idx>(rst); auto eset = ToEdgeSet<Idx>(rst);
ASSERT_EQ(eset.size(), 2); ASSERT_EQ(eset.size(), 2);
if (has_data) { if (has_data) {
ASSERT_TRUE(eset.count(std::make_tuple(0, 1, 3))); ASSERT_TRUE(eset.count(std::make_tuple(0, 1, 3)));
ASSERT_TRUE(eset.count(std::make_tuple(3, 3, 4))); ASSERT_TRUE(eset.count(std::make_tuple(3, 3, 4)));
} else { } else {
ASSERT_TRUE(eset.count(std::make_tuple(0, 0, 0))); ASSERT_TRUE(eset.count(std::make_tuple(0, 0, 0)));
ASSERT_TRUE(eset.count(std::make_tuple(3, 3, 4))); ASSERT_TRUE(eset.count(std::make_tuple(3, 3, 4)));
} }
} }
} }
...@@ -856,16 +855,11 @@ void _TestCSRSamplingBiased(bool has_data) { ...@@ -856,16 +855,11 @@ void _TestCSRSamplingBiased(bool has_data) {
// 1 - 1 // 1 - 1
// 3 - 2,3 // 3 - 2,3
NDArray tag_offset = NDArray::FromVector( NDArray tag_offset = NDArray::FromVector(
std::vector<Idx>({0, 1, 2, std::vector<Idx>({0, 1, 2, 0, 0, 1, 0, 0, 0, 0, 1, 2}));
0, 0, 1,
0, 0, 0,
0, 1, 2}));
tag_offset = tag_offset.CreateView({4, 3}, tag_offset->dtype); tag_offset = tag_offset.CreateView({4, 3}, tag_offset->dtype);
IdArray rows = NDArray::FromVector(std::vector<Idx>({0, 1, 3})); IdArray rows = NDArray::FromVector(std::vector<Idx>({0, 1, 3}));
FloatArray bias = NDArray::FromVector( FloatArray bias = NDArray::FromVector(std::vector<FloatType>({0, 0.5}));
std::vector<FloatType>({0, 0.5}) for (int k = 0; k < 10; ++k) {
);
for (int k = 0 ; k < 10 ; ++k) {
auto rst = CSRRowWiseSamplingBiased(mat, rows, 1, tag_offset, bias, false); auto rst = CSRRowWiseSamplingBiased(mat, rows, 1, tag_offset, bias, false);
CheckSampledResult<Idx>(rst, rows, has_data); CheckSampledResult<Idx>(rst, rows, has_data);
auto eset = ToEdgeSet<Idx>(rst); auto eset = ToEdgeSet<Idx>(rst);
...@@ -879,7 +873,7 @@ void _TestCSRSamplingBiased(bool has_data) { ...@@ -879,7 +873,7 @@ void _TestCSRSamplingBiased(bool has_data) {
ASSERT_TRUE(eset.count(std::make_tuple(3, 3, 4))); ASSERT_TRUE(eset.count(std::make_tuple(3, 3, 4)));
} }
} }
for (int k = 0 ; k < 10 ; ++k) { for (int k = 0; k < 10; ++k) {
auto rst = CSRRowWiseSamplingBiased(mat, rows, 3, tag_offset, bias, true); auto rst = CSRRowWiseSamplingBiased(mat, rows, 3, tag_offset, bias, true);
CheckSampledResult<Idx>(rst, rows, has_data); CheckSampledResult<Idx>(rst, rows, has_data);
auto eset = ToEdgeSet<Idx>(rst); auto eset = ToEdgeSet<Idx>(rst);
......
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <vector>
#include <algorithm> #include <algorithm>
#include <iostream> #include <iostream>
#include "./common.h" #include <vector>
#include "../../src/random/cpu/sample_utils.h" #include "../../src/random/cpu/sample_utils.h"
#include "./common.h"
using namespace dgl; using namespace dgl;
using namespace dgl::aten; using namespace dgl::aten;
...@@ -11,7 +13,7 @@ using namespace dgl::aten; ...@@ -11,7 +13,7 @@ using namespace dgl::aten;
// TODO: adapt this to Random::Choice // TODO: adapt this to Random::Choice
template <typename Idx, typename DType> template <typename Idx, typename DType>
void _TestWithReplacement(RandomEngine *re) { void _TestWithReplacement(RandomEngine* re) {
Idx n_categories = 100; Idx n_categories = 100;
Idx n_rolls = 1000000; Idx n_rolls = 1000000;
std::vector<DType> _prob; std::vector<DType> _prob;
...@@ -20,12 +22,11 @@ void _TestWithReplacement(RandomEngine *re) { ...@@ -20,12 +22,11 @@ void _TestWithReplacement(RandomEngine *re) {
_prob.push_back(re->Uniform<DType>()); _prob.push_back(re->Uniform<DType>());
accum += _prob.back(); accum += _prob.back();
} }
for (Idx i = 0; i < n_categories; ++i) for (Idx i = 0; i < n_categories; ++i) _prob[i] /= accum;
_prob[i] /= accum;
FloatArray prob = NDArray::FromVector(_prob); FloatArray prob = NDArray::FromVector(_prob);
auto _check_given_sampler = [n_categories, n_rolls, &_prob]( auto _check_given_sampler = [n_categories, n_rolls,
utils::BaseSampler<Idx> *s) { &_prob](utils::BaseSampler<Idx>* s) {
std::vector<Idx> counter(n_categories, 0); std::vector<Idx> counter(n_categories, 0);
for (Idx i = 0; i < n_rolls; ++i) { for (Idx i = 0; i < n_rolls; ++i) {
Idx dice = s->Draw(); Idx dice = s->Draw();
...@@ -67,14 +68,13 @@ TEST(SampleUtilsTest, TestWithReplacement) { ...@@ -67,14 +68,13 @@ TEST(SampleUtilsTest, TestWithReplacement) {
}; };
template <typename Idx, typename DType> template <typename Idx, typename DType>
void _TestWithoutReplacementOrder(RandomEngine *re) { void _TestWithoutReplacementOrder(RandomEngine* re) {
// TODO(BarclayII): is there a reliable way to do this test? // TODO(BarclayII): is there a reliable way to do this test?
std::vector<DType> _prob = {1e6f, 1e-6f, 1e-2f, 1e2f}; std::vector<DType> _prob = {1e6f, 1e-6f, 1e-2f, 1e2f};
FloatArray prob = NDArray::FromVector(_prob); FloatArray prob = NDArray::FromVector(_prob);
std::vector<Idx> ground_truth = {0, 3, 2, 1}; std::vector<Idx> ground_truth = {0, 3, 2, 1};
auto _check_given_sampler = [&ground_truth]( auto _check_given_sampler = [&ground_truth](utils::BaseSampler<Idx>* s) {
utils::BaseSampler<Idx> *s) {
for (size_t i = 0; i < ground_truth.size(); ++i) { for (size_t i = 0; i < ground_truth.size(); ++i) {
Idx dice = s->Draw(); Idx dice = s->Draw();
ASSERT_EQ(dice, ground_truth[i]); ASSERT_EQ(dice, ground_truth[i]);
...@@ -102,22 +102,19 @@ TEST(SampleUtilsTest, TestWithoutReplacementOrder) { ...@@ -102,22 +102,19 @@ TEST(SampleUtilsTest, TestWithoutReplacementOrder) {
}; };
template <typename Idx, typename DType> template <typename Idx, typename DType>
void _TestWithoutReplacementUnique(RandomEngine *re) { void _TestWithoutReplacementUnique(RandomEngine* re) {
Idx N = 1000000; Idx N = 1000000;
std::vector<DType> _likelihood; std::vector<DType> _likelihood;
for (Idx i = 0; i < N; ++i) for (Idx i = 0; i < N; ++i) _likelihood.push_back(re->Uniform<DType>());
_likelihood.push_back(re->Uniform<DType>());
FloatArray likelihood = NDArray::FromVector(_likelihood); FloatArray likelihood = NDArray::FromVector(_likelihood);
auto _check_given_sampler = [N]( auto _check_given_sampler = [N](utils::BaseSampler<Idx>* s) {
utils::BaseSampler<Idx> *s) {
std::vector<int> cnt(N, 0); std::vector<int> cnt(N, 0);
for (Idx i = 0; i < N; ++i) { for (Idx i = 0; i < N; ++i) {
Idx dice = s->Draw(); Idx dice = s->Draw();
cnt[dice]++; cnt[dice]++;
} }
for (Idx i = 0; i < N; ++i) for (Idx i = 0; i < N; ++i) ASSERT_EQ(cnt[i], 1);
ASSERT_EQ(cnt[i], 1);
}; };
utils::AliasSampler<Idx, DType, false> as(re, likelihood); utils::AliasSampler<Idx, DType, false> as(re, likelihood);
...@@ -242,15 +239,15 @@ void _TestBiasedChoice(RandomEngine* re) { ...@@ -242,15 +239,15 @@ void _TestBiasedChoice(RandomEngine* re) {
{ {
Idx sample_num = 100000; Idx sample_num = 100000;
Idx population = 1000000; Idx population = 1000000;
Idx split[] = {0, population/2, population}; Idx split[] = {0, population / 2, population};
FloatArray bias = NDArray::FromVector(std::vector<FloatType>({1, 3})); FloatArray bias = NDArray::FromVector(std::vector<FloatType>({1, 3}));
IdArray rst = re->BiasedChoice<Idx, FloatType>(sample_num, split, bias, true); IdArray rst =
auto rst_data = static_cast<Idx *>(rst->data); re->BiasedChoice<Idx, FloatType>(sample_num, split, bias, true);
auto rst_data = static_cast<Idx*>(rst->data);
Idx larger = 0; Idx larger = 0;
for (Idx i = 0 ; i < sample_num ; ++i) for (Idx i = 0; i < sample_num; ++i)
if (rst_data[i] >= population / 2) if (rst_data[i] >= population / 2) larger++;
larger++;
ASSERT_LE(fabs((double)larger / sample_num - 0.75), 1e-2); ASSERT_LE(fabs((double)larger / sample_num - 0.75), 1e-2);
} }
// without replacement // without replacement
...@@ -260,8 +257,9 @@ void _TestBiasedChoice(RandomEngine* re) { ...@@ -260,8 +257,9 @@ void _TestBiasedChoice(RandomEngine* re) {
Idx split[] = {0, sample_num, population}; Idx split[] = {0, sample_num, population};
FloatArray bias = NDArray::FromVector(std::vector<FloatType>({1, 0})); FloatArray bias = NDArray::FromVector(std::vector<FloatType>({1, 0}));
IdArray rst = re->BiasedChoice<Idx, FloatType>(sample_num, split, bias, false); IdArray rst =
auto rst_data = static_cast<Idx *>(rst->data); re->BiasedChoice<Idx, FloatType>(sample_num, split, bias, false);
auto rst_data = static_cast<Idx*>(rst->data);
std::set<Idx> idxset; std::set<Idx> idxset;
for (int64_t i = 0; i < sample_num; ++i) { for (int64_t i = 0; i < sample_num; ++i) {
......
...@@ -2,10 +2,12 @@ ...@@ -2,10 +2,12 @@
#include <dgl/immutable_graph.h> #include <dgl/immutable_graph.h>
#include <dmlc/memory_io.h> #include <dmlc/memory_io.h>
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <algorithm> #include <algorithm>
#include <iostream> #include <iostream>
#include <memory> #include <memory>
#include <vector> #include <vector>
#include "../../src/graph/heterograph.h" #include "../../src/graph/heterograph.h"
#include "../../src/graph/unit_graph.h" #include "../../src/graph/unit_graph.h"
#include "./common.h" #include "./common.h"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment