Unverified Commit 8ac27dad authored by Hongzhi (Steve), Chen's avatar Hongzhi (Steve), Chen Committed by GitHub
Browse files

[Misc] clang-format auto fix. (#4824)



* [Misc] clang-format auto fix.

* blabla

* ablabla

* blabla
Co-authored-by: default avatarSteve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent bcd37684
......@@ -6,13 +6,13 @@
#ifndef DGL_ARRAY_KERNEL_DECL_H_
#define DGL_ARRAY_KERNEL_DECL_H_
#include <dgl/bcast.h>
#include <dgl/base_heterograph.h>
#include <dgl/bcast.h>
#include <dgl/runtime/ndarray.h>
#include <string>
#include <vector>
#include <utility>
#include <vector>
namespace dgl {
namespace aten {
......@@ -21,25 +21,20 @@ namespace aten {
* @brief Generalized Sparse Matrix Dense Matrix Multiplication on Csr format.
*/
template <int XPU, typename IdType, typename DType>
void SpMMCsr(const std::string& op, const std::string& reduce,
const BcastOff& bcast,
const aten::CSRMatrix& csr,
NDArray ufeat,
NDArray efeat,
NDArray out,
void SpMMCsr(
const std::string& op, const std::string& reduce, const BcastOff& bcast,
const aten::CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
std::vector<NDArray> out_aux);
/**
* @brief Generalized Sparse Matrix Dense Matrix Multiplication on Csr format
with heterograph support.
* with heterograph support.
*/
template <int XPU, typename IdType, typename DType>
void SpMMCsrHetero(const std::string& op, const std::string& reduce,
const BcastOff& bcast,
const std::vector<CSRMatrix>& csr,
const std::vector<NDArray>& ufeat,
const std::vector<NDArray>& efeat,
std::vector<NDArray>* out,
void SpMMCsrHetero(
const std::string& op, const std::string& reduce, const BcastOff& bcast,
const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,
const std::vector<NDArray>& efeat, std::vector<NDArray>* out,
std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_eid,
const std::vector<dgl_type_t>& out_eid);
......@@ -47,144 +42,108 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
* @brief Generalized Sparse Matrix Dense Matrix Multiplication on Coo format.
*/
template <int XPU, typename IdType, typename DType>
void SpMMCoo(const std::string& op, const std::string& reduce,
const BcastOff& bcast,
const aten::COOMatrix& coo,
NDArray ufeat,
NDArray efeat,
NDArray out,
void SpMMCoo(
const std::string& op, const std::string& reduce, const BcastOff& bcast,
const aten::COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
std::vector<NDArray> out_aux);
/**
* @brief Generalized Sampled Dense-Dense Matrix Multiplication on Csr format.
*/
template <int XPU, typename IdType, typename DType>
void SDDMMCsr(const std::string& op,
const BcastOff& bcast,
const aten::CSRMatrix& csr,
NDArray lhs,
NDArray rhs,
NDArray out,
int lhs_target,
int rhs_target);
void SDDMMCsr(
const std::string& op, const BcastOff& bcast, const aten::CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
/**
* @brief Generalized Sampled Dense-Dense Matrix Multiplication on Csr
format with heterograph support.
* @brief Generalized Sampled Dense-Dense Matrix Multiplication on Csr format
* with heterograph support.
*/
template <int XPU, typename IdType, typename DType>
void SDDMMCsrHetero(const std::string& op,
const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& vec_lhs,
const std::vector<NDArray>& vec_rhs,
std::vector<NDArray> vec_out,
int lhs_target,
int rhs_target,
const std::vector<dgl_type_t>& ufeat_eid,
void SDDMMCsrHetero(
const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& vec_lhs,
const std::vector<NDArray>& vec_rhs, std::vector<NDArray> vec_out,
int lhs_target, int rhs_target, const std::vector<dgl_type_t>& ufeat_eid,
const std::vector<dgl_type_t>& out_eid);
/**
* @brief Generalized Sampled Dense-Dense Matrix Multiplication on Coo format.
*/
template <int XPU, typename IdType, typename DType>
void SDDMMCoo(const std::string& op,
const BcastOff& bcast,
const aten::COOMatrix& coo,
NDArray lhs,
NDArray rhs,
NDArray out,
int lhs_target,
int rhs_target);
void SDDMMCoo(
const std::string& op, const BcastOff& bcast, const aten::COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
/**
* @brief Generalized Sampled Dense-Dense Matrix Multiplication on Coo
format with heterograph support.
* @brief Generalized Sampled Dense-Dense Matrix Multiplication on Coo format
* with heterograph support.
*/
template <int XPU, typename IdType, typename DType>
void SDDMMCooHetero(const std::string& op,
const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo,
const std::vector<NDArray>& vec_lhs,
const std::vector<NDArray>& vec_rhs,
std::vector<NDArray> vec_out,
int lhs_target,
int rhs_target,
const std::vector<dgl_type_t>& lhs_eid,
void SDDMMCooHetero(
const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& vec_lhs,
const std::vector<NDArray>& vec_rhs, std::vector<NDArray> vec_out,
int lhs_target, int rhs_target, const std::vector<dgl_type_t>& lhs_eid,
const std::vector<dgl_type_t>& rhs_eid);
/**
* @brief Generalized Dense Matrix-Matrix Multiplication according to relation types.
* @brief Generalized Dense Matrix-Matrix Multiplication according to relation
* types.
*/
template <int XPU, typename IdType, typename DType>
void GatherMM(const NDArray A,
const NDArray B,
NDArray out,
const NDArray idx_a,
void GatherMM(
const NDArray A, const NDArray B, NDArray out, const NDArray idx_a,
const NDArray idx_b);
/**
* @brief Generalized Dense Matrix-Matrix Multiplication according to relation types.
* @brief Generalized Dense Matrix-Matrix Multiplication according to relation
* types.
*/
template <int XPU, typename IdType, typename DType>
void GatherMMScatter(const NDArray A,
const NDArray B,
NDArray out,
const NDArray idx_a,
const NDArray idx_b,
const NDArray idx_c);
void GatherMMScatter(
const NDArray A, const NDArray B, NDArray out, const NDArray idx_a,
const NDArray idx_b, const NDArray idx_c);
/**
* @brief Generalized segmented dense Matrix-Matrix Multiplication.
*/
template <int XPU, typename IdType, typename DType>
void SegmentMM(const NDArray A,
const NDArray B,
NDArray out,
const NDArray seglen_A,
void SegmentMM(
const NDArray A, const NDArray B, NDArray out, const NDArray seglen_A,
bool a_trans, bool b_trans);
template <int XPU, typename IdType, typename DType>
void SegmentMMBackwardB(const NDArray A,
const NDArray dC,
NDArray dB,
const NDArray seglen);
void SegmentMMBackwardB(
const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
/**
* @brief Segment reduce.
*/
template <int XPU, typename IdType, typename DType>
void SegmentReduce(const std::string& op,
NDArray feat,
NDArray offsets,
NDArray out,
void SegmentReduce(
const std::string& op, NDArray feat, NDArray offsets, NDArray out,
NDArray arg);
/**
* @brief Scatter Add on first dimension.
*/
template <int XPU, typename IdType, typename DType>
void ScatterAdd(NDArray feat,
NDArray idx,
NDArray out);
void ScatterAdd(NDArray feat, NDArray idx, NDArray out);
/**
* @brief Update gradients for reduce operator max and min on first dimension.
*/
template <int XPU, typename IdType, typename DType>
void UpdateGradMinMax_hetero(const HeteroGraphPtr& g,
const std::string& op,
const std::vector<NDArray>& feat,
const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx_etype,
std::vector<NDArray>* out);
void UpdateGradMinMax_hetero(
const HeteroGraphPtr& g, const std::string& op,
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
/**
* @brief Backward function of segment cmp.
*/
template <int XPU, typename IdType, typename DType>
void BackwardSegmentCmp(NDArray feat,
NDArray arg,
NDArray out);
void BackwardSegmentCmp(NDArray feat, NDArray arg, NDArray out);
/**
* @brief Sparse-sparse matrix multiplication
......@@ -200,9 +159,7 @@ void BackwardSegmentCmp(NDArray feat,
*/
template <int XPU, typename IdType, typename DType>
std::pair<CSRMatrix, NDArray> CSRMM(
const CSRMatrix& A,
NDArray A_weights,
const CSRMatrix& B,
const CSRMatrix& A, NDArray A_weights, const CSRMatrix& B,
NDArray B_weights);
/**
......@@ -217,29 +174,22 @@ std::pair<CSRMatrix, NDArray> CSRMM(
*/
template <int XPU, typename IdType, typename DType>
std::pair<CSRMatrix, NDArray> CSRSum(
const std::vector<CSRMatrix>& A,
const std::vector<NDArray>& A_weights);
const std::vector<CSRMatrix>& A, const std::vector<NDArray>& A_weights);
/**
* @brief Edge_softmax_csr forward function on Csr format.
*/
template <int XPU, typename IdType, typename DType>
void Edge_softmax_csr_forward(const std::string& op,
const BcastOff& bcast,
const aten::CSRMatrix& csr,
NDArray ufeat,
NDArray efeat,
NDArray out);
void Edge_softmax_csr_forward(
const std::string& op, const BcastOff& bcast, const aten::CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out);
/**
* @brief Edge_softmax_csr backward function on Csr format.
*/
template <int XPU, typename IdType, typename DType>
void Edge_softmax_csr_backward(const std::string& op,
const BcastOff& bcast,
const aten::CSRMatrix& csr,
NDArray ufeat,
NDArray efeat,
NDArray out);
void Edge_softmax_csr_backward(
const std::string& op, const BcastOff& bcast, const aten::CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out);
} // namespace aten
} // namespace dgl
......
......@@ -7,15 +7,17 @@
#ifndef DGL_GRAPH_HETEROGRAPH_H_
#define DGL_GRAPH_HETEROGRAPH_H_
#include <dgl/runtime/shared_mem.h>
#include <dgl/base_heterograph.h>
#include <dgl/lazy.h>
#include <utility>
#include <string>
#include <vector>
#include <dgl/runtime/shared_mem.h>
#include <memory>
#include <set>
#include <string>
#include <tuple>
#include <memory>
#include <utility>
#include <vector>
#include "./unit_graph.h"
#include "shared_mem_manager.h"
......@@ -25,8 +27,7 @@ namespace dgl {
class HeteroGraph : public BaseHeteroGraph {
public:
HeteroGraph(
GraphPtr meta_graph,
const std::vector<HeteroGraphPtr>& rel_graphs,
GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& rel_graphs,
const std::vector<int64_t>& num_nodes_per_type = {});
HeteroGraphPtr GetRelationGraph(dgl_type_t etype) const override {
......@@ -46,31 +47,21 @@ class HeteroGraph : public BaseHeteroGraph {
LOG(FATAL) << "Bipartite graph is not mutable.";
}
void Clear() override {
LOG(FATAL) << "Bipartite graph is not mutable.";
}
void Clear() override { LOG(FATAL) << "Bipartite graph is not mutable."; }
DGLDataType DataType() const override {
return relation_graphs_[0]->DataType();
}
DGLContext Context() const override {
return relation_graphs_[0]->Context();
}
DGLContext Context() const override { return relation_graphs_[0]->Context(); }
bool IsPinned() const override {
return relation_graphs_[0]->IsPinned();
}
bool IsPinned() const override { return relation_graphs_[0]->IsPinned(); }
uint8_t NumBits() const override {
return relation_graphs_[0]->NumBits();
}
uint8_t NumBits() const override { return relation_graphs_[0]->NumBits(); }
bool IsMultigraph() const override;
bool IsReadonly() const override {
return true;
}
bool IsReadonly() const override { return true; }
uint64_t NumVertices(dgl_type_t vtype) const override {
CHECK(meta_graph_->HasVertex(vtype)) << "Invalid vertex type: " << vtype;
......@@ -91,11 +82,13 @@ class HeteroGraph : public BaseHeteroGraph {
BoolArray HasVertices(dgl_type_t vtype, IdArray vids) const override;
bool HasEdgeBetween(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const override {
bool HasEdgeBetween(
dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const override {
return GetRelationGraph(etype)->HasEdgeBetween(0, src, dst);
}
BoolArray HasEdgesBetween(dgl_type_t etype, IdArray src_ids, IdArray dst_ids) const override {
BoolArray HasEdgesBetween(
dgl_type_t etype, IdArray src_ids, IdArray dst_ids) const override {
return GetRelationGraph(etype)->HasEdgesBetween(0, src_ids, dst_ids);
}
......@@ -111,15 +104,18 @@ class HeteroGraph : public BaseHeteroGraph {
return GetRelationGraph(etype)->EdgeId(0, src, dst);
}
EdgeArray EdgeIdsAll(dgl_type_t etype, IdArray src, IdArray dst) const override {
EdgeArray EdgeIdsAll(
dgl_type_t etype, IdArray src, IdArray dst) const override {
return GetRelationGraph(etype)->EdgeIdsAll(0, src, dst);
}
IdArray EdgeIdsOne(dgl_type_t etype, IdArray src, IdArray dst) const override {
IdArray EdgeIdsOne(
dgl_type_t etype, IdArray src, IdArray dst) const override {
return GetRelationGraph(etype)->EdgeIdsOne(0, src, dst);
}
std::pair<dgl_id_t, dgl_id_t> FindEdge(dgl_type_t etype, dgl_id_t eid) const override {
std::pair<dgl_id_t, dgl_id_t> FindEdge(
dgl_type_t etype, dgl_id_t eid) const override {
return GetRelationGraph(etype)->FindEdge(0, eid);
}
......@@ -143,7 +139,8 @@ class HeteroGraph : public BaseHeteroGraph {
return GetRelationGraph(etype)->OutEdges(0, vids);
}
EdgeArray Edges(dgl_type_t etype, const std::string &order = "") const override {
EdgeArray Edges(
dgl_type_t etype, const std::string& order = "") const override {
return GetRelationGraph(etype)->Edges(0, order);
}
......@@ -180,7 +177,7 @@ class HeteroGraph : public BaseHeteroGraph {
}
std::vector<IdArray> GetAdj(
dgl_type_t etype, bool transpose, const std::string &fmt) const override {
dgl_type_t etype, bool transpose, const std::string& fmt) const override {
return GetRelationGraph(etype)->GetAdj(0, transpose, fmt);
}
......@@ -196,7 +193,8 @@ class HeteroGraph : public BaseHeteroGraph {
return GetRelationGraph(etype)->GetCSRMatrix(0);
}
SparseFormat SelectFormat(dgl_type_t etype, dgl_format_code_t preferred_formats) const override {
SparseFormat SelectFormat(
dgl_type_t etype, dgl_format_code_t preferred_formats) const override {
return GetRelationGraph(etype)->SelectFormat(0, preferred_formats);
}
......@@ -208,14 +206,17 @@ class HeteroGraph : public BaseHeteroGraph {
return GetRelationGraph(0)->GetCreatedFormats();
}
HeteroSubgraph VertexSubgraph(const std::vector<IdArray>& vids) const override;
HeteroSubgraph VertexSubgraph(
const std::vector<IdArray>& vids) const override;
HeteroSubgraph EdgeSubgraph(
const std::vector<IdArray>& eids, bool preserve_nodes = false) const override;
const std::vector<IdArray>& eids,
bool preserve_nodes = false) const override;
HeteroGraphPtr GetGraphInFormat(dgl_format_code_t formats) const override;
FlattenedHeteroGraphPtr Flatten(const std::vector<dgl_type_t>& etypes) const override;
FlattenedHeteroGraphPtr Flatten(
const std::vector<dgl_type_t>& etypes) const override;
GraphPtr AsImmutableGraph() const override;
......@@ -229,25 +230,22 @@ class HeteroGraph : public BaseHeteroGraph {
static HeteroGraphPtr AsNumBits(HeteroGraphPtr g, uint8_t bits);
/** @brief Copy the data to another context */
static HeteroGraphPtr CopyTo(HeteroGraphPtr g, const DGLContext &ctx);
static HeteroGraphPtr CopyTo(HeteroGraphPtr g, const DGLContext& ctx);
/**
* @brief Pin all relation graphs of the current graph.
* @note The graph will be pinned inplace. Behavior depends on the current context,
* kDGLCPU: will be pinned;
* IsPinned: directly return;
* kDGLCUDA: invalid, will throw an error.
* The context check is deferred to pinning the NDArray.
* @note The graph will be pinned in place. Behavior depends on the current
* context:
*   - kDGLCPU: will be pinned;
*   - IsPinned: directly return;
*   - kDGLCUDA: invalid, will throw an error.
* The context check is deferred to pinning the NDArray.
*/
void PinMemory_() override;
/**
* @brief Unpin all relation graphs of the current graph.
* @note The graph will be unpinned inplace. Behavior depends on the current context,
* IsPinned: will be unpinned;
* others: directly return.
* The context check is deferred to unpinning the NDArray.
* @note The graph will be unpinned in place. Behavior depends on the current
* context:
*   - IsPinned: will be unpinned;
*   - others: directly return.
* The context check is deferred to unpinning the NDArray.
*/
void UnpinMemory_();
......@@ -259,17 +257,21 @@ class HeteroGraph : public BaseHeteroGraph {
/** @brief Copy the data to shared memory.
*
* Also save names of node types and edge types of the HeteroGraph object to shared memory
* Also save names of node types and edge types of the HeteroGraph object to
* shared memory
*/
static HeteroGraphPtr CopyToSharedMem(
HeteroGraphPtr g, const std::string& name, const std::vector<std::string>& ntypes,
const std::vector<std::string>& etypes, const std::set<std::string>& fmts);
HeteroGraphPtr g, const std::string& name,
const std::vector<std::string>& ntypes,
const std::vector<std::string>& etypes,
const std::set<std::string>& fmts);
/** @brief Create a heterograph from shared memory.
* @return the HeteroGraphPtr, names of node types, names of edge types
*/
static std::tuple<HeteroGraphPtr, std::vector<std::string>, std::vector<std::string>>
CreateFromSharedMem(const std::string &name);
static std::tuple<
HeteroGraphPtr, std::vector<std::string>, std::vector<std::string>>
CreateFromSharedMem(const std::string& name);
/** @brief Create a LineGraph of self */
HeteroGraphPtr LineGraph(bool backtracking) const;
......@@ -294,7 +296,8 @@ class HeteroGraph : public BaseHeteroGraph {
/** @brief The shared memory object for meta info*/
std::shared_ptr<runtime::SharedMemory> shared_mem_;
/** @brief The name of the shared memory. Return empty string if it is not in shared memory. */
/** @brief The name of the shared memory. Returns an empty string if the graph
* is not in shared memory. */
std::string SharedMemName() const;
/** @brief template class for Flatten operation
......@@ -304,15 +307,14 @@ class HeteroGraph : public BaseHeteroGraph {
* @return pointer of FlattenedHeteroGraph
*/
template <class IdType>
FlattenedHeteroGraphPtr FlattenImpl(const std::vector<dgl_type_t>& etypes) const;
FlattenedHeteroGraphPtr FlattenImpl(
const std::vector<dgl_type_t>& etypes) const;
};
} // namespace dgl
namespace dmlc {
DMLC_DECLARE_TRAITS(has_saveload, dgl::HeteroGraph, true);
} // namespace dmlc
#endif // DGL_GRAPH_HETEROGRAPH_H_
......@@ -4,11 +4,12 @@
* @brief Definition of neighborhood-based sampler APIs.
*/
#include <dgl/runtime/container.h>
#include <dgl/packed_func_ext.h>
#include <dgl/array.h>
#include <dgl/aten/macro.h>
#include <dgl/packed_func_ext.h>
#include <dgl/runtime/container.h>
#include <dgl/sampling/neighbor.h>
#include "../../../c_api_common.h"
#include "../../unit_graph.h"
......@@ -19,18 +20,15 @@ namespace dgl {
namespace sampling {
HeteroSubgraph ExcludeCertainEdges(
const HeteroSubgraph& sg,
const std::vector<IdArray>& exclude_edges) {
const HeteroSubgraph& sg, const std::vector<IdArray>& exclude_edges) {
HeteroGraphPtr hg_view = HeteroGraphRef(sg.graph).sptr();
std::vector<IdArray> remain_induced_edges(hg_view->NumEdgeTypes());
std::vector<IdArray> remain_edges(hg_view->NumEdgeTypes());
for (dgl_type_t etype = 0; etype < hg_view->NumEdgeTypes(); ++etype) {
IdArray edge_ids = Range(0,
sg.induced_edges[etype]->shape[0],
sg.induced_edges[etype]->dtype.bits,
sg.induced_edges[etype]->ctx);
IdArray edge_ids = Range(
0, sg.induced_edges[etype]->shape[0],
sg.induced_edges[etype]->dtype.bits, sg.induced_edges[etype]->ctx);
if (exclude_edges[etype].GetSize() == 0 || edge_ids.GetSize() == 0) {
remain_edges[etype] = edge_ids;
remain_induced_edges[etype] = sg.induced_edges[etype];
......@@ -40,13 +38,14 @@ HeteroSubgraph ExcludeCertainEdges(
IdType* idx_data = edge_ids.Ptr<IdType>();
IdType* induced_edges_data = sg.induced_edges[etype].Ptr<IdType>();
const IdType exclude_edges_len = exclude_edges[etype]->shape[0];
std::sort(exclude_edges[etype].Ptr<IdType>(),
std::sort(
exclude_edges[etype].Ptr<IdType>(),
exclude_edges[etype].Ptr<IdType>() + exclude_edges_len);
const IdType* exclude_edges_data = exclude_edges[etype].Ptr<IdType>();
IdType outId = 0;
for (IdType i = 0; i != sg.induced_edges[etype]->shape[0]; ++i) {
if (!std::binary_search(exclude_edges_data,
exclude_edges_data + exclude_edges_len,
if (!std::binary_search(
exclude_edges_data, exclude_edges_data + exclude_edges_len,
induced_edges_data[i])) {
induced_edges_data[outId] = induced_edges_data[i];
idx_data[outId] = idx_data[i];
......@@ -54,7 +53,8 @@ HeteroSubgraph ExcludeCertainEdges(
}
}
remain_edges[etype] = aten::IndexSelect(edge_ids, 0, outId);
remain_induced_edges[etype] = aten::IndexSelect(sg.induced_edges[etype], 0, outId);
remain_induced_edges[etype] =
aten::IndexSelect(sg.induced_edges[etype], 0, outId);
});
}
HeteroSubgraph subg = hg_view->EdgeSubgraph(remain_edges, true);
......@@ -63,14 +63,10 @@ HeteroSubgraph ExcludeCertainEdges(
}
HeteroSubgraph SampleNeighbors(
const HeteroGraphPtr hg,
const std::vector<IdArray>& nodes,
const std::vector<int64_t>& fanouts,
EdgeDir dir,
const HeteroGraphPtr hg, const std::vector<IdArray>& nodes,
const std::vector<int64_t>& fanouts, EdgeDir dir,
const std::vector<NDArray>& prob_or_mask,
const std::vector<IdArray>& exclude_edges,
bool replace) {
const std::vector<IdArray>& exclude_edges, bool replace) {
// sanity check
CHECK_EQ(nodes.size(), hg->NumVertexTypes())
<< "Number of node ID tensors must match the number of node types.";
......@@ -87,42 +83,46 @@ HeteroSubgraph SampleNeighbors(
auto pair = hg->meta_graph()->FindEdge(etype);
const dgl_type_t src_vtype = pair.first;
const dgl_type_t dst_vtype = pair.second;
const IdArray nodes_ntype = nodes[(dir == EdgeDir::kOut)? src_vtype : dst_vtype];
const IdArray nodes_ntype =
nodes[(dir == EdgeDir::kOut) ? src_vtype : dst_vtype];
const int64_t num_nodes = nodes_ntype->shape[0];
if (num_nodes == 0 || fanouts[etype] == 0) {
// Nothing to sample for this etype, create a placeholder relation graph
subrels[etype] = UnitGraph::Empty(
hg->GetRelationGraph(etype)->NumVertexTypes(),
hg->NumVertices(src_vtype),
hg->NumVertices(dst_vtype),
hg->NumVertices(src_vtype), hg->NumVertices(dst_vtype),
hg->DataType(), ctx);
induced_edges[etype] = aten::NullArray(hg->DataType(), ctx);
} else {
COOMatrix sampled_coo;
// sample from one relation graph
auto req_fmt = (dir == EdgeDir::kOut)? CSR_CODE : CSC_CODE;
auto req_fmt = (dir == EdgeDir::kOut) ? CSR_CODE : CSC_CODE;
auto avail_fmt = hg->SelectFormat(etype, req_fmt);
switch (avail_fmt) {
case SparseFormat::kCOO:
if (dir == EdgeDir::kIn) {
sampled_coo = aten::COOTranspose(aten::COORowWiseSampling(
aten::COOTranspose(hg->GetCOOMatrix(etype)),
nodes_ntype, fanouts[etype], prob_or_mask[etype], replace));
aten::COOTranspose(hg->GetCOOMatrix(etype)), nodes_ntype,
fanouts[etype], prob_or_mask[etype], replace));
} else {
sampled_coo = aten::COORowWiseSampling(
hg->GetCOOMatrix(etype), nodes_ntype, fanouts[etype], prob_or_mask[etype], replace);
hg->GetCOOMatrix(etype), nodes_ntype, fanouts[etype],
prob_or_mask[etype], replace);
}
break;
case SparseFormat::kCSR:
CHECK(dir == EdgeDir::kOut) << "Cannot sample out edges on CSC matrix.";
CHECK(dir == EdgeDir::kOut)
<< "Cannot sample out edges on CSC matrix.";
sampled_coo = aten::CSRRowWiseSampling(
hg->GetCSRMatrix(etype), nodes_ntype, fanouts[etype], prob_or_mask[etype], replace);
hg->GetCSRMatrix(etype), nodes_ntype, fanouts[etype],
prob_or_mask[etype], replace);
break;
case SparseFormat::kCSC:
CHECK(dir == EdgeDir::kIn) << "Cannot sample in edges on CSR matrix.";
sampled_coo = aten::CSRRowWiseSampling(
hg->GetCSCMatrix(etype), nodes_ntype, fanouts[etype], prob_or_mask[etype], replace);
hg->GetCSCMatrix(etype), nodes_ntype, fanouts[etype],
prob_or_mask[etype], replace);
sampled_coo = aten::COOTranspose(sampled_coo);
break;
default:
......@@ -130,14 +130,15 @@ HeteroSubgraph SampleNeighbors(
}
subrels[etype] = UnitGraph::CreateFromCOO(
hg->GetRelationGraph(etype)->NumVertexTypes(), sampled_coo.num_rows, sampled_coo.num_cols,
sampled_coo.row, sampled_coo.col);
hg->GetRelationGraph(etype)->NumVertexTypes(), sampled_coo.num_rows,
sampled_coo.num_cols, sampled_coo.row, sampled_coo.col);
induced_edges[etype] = sampled_coo.data;
}
}
HeteroSubgraph ret;
ret.graph = CreateHeteroGraph(hg->meta_graph(), subrels, hg->NumVerticesPerType());
ret.graph =
CreateHeteroGraph(hg->meta_graph(), subrels, hg->NumVerticesPerType());
ret.induced_vertices.resize(hg->NumVertexTypes());
ret.induced_edges = std::move(induced_edges);
if (!exclude_edges.empty()) {
......@@ -147,15 +148,11 @@ HeteroSubgraph SampleNeighbors(
}
HeteroSubgraph SampleNeighborsEType(
const HeteroGraphPtr hg,
const IdArray nodes,
const HeteroGraphPtr hg, const IdArray nodes,
const std::vector<int64_t>& eid2etype_offset,
const std::vector<int64_t>& fanouts,
EdgeDir dir,
const std::vector<FloatArray>& prob,
bool replace,
const std::vector<int64_t>& fanouts, EdgeDir dir,
const std::vector<FloatArray>& prob, bool replace,
bool rowwise_etype_sorted) {
CHECK_EQ(1, hg->NumVertexTypes())
<< "SampleNeighborsEType only work with homogeneous graph";
CHECK_EQ(1, hg->NumEdgeTypes())
......@@ -178,39 +175,39 @@ HeteroSubgraph SampleNeighborsEType(
}
if (num_nodes == 0 || (same_fanout && fanout_value == 0)) {
subrels[etype] = UnitGraph::Empty(1,
hg->NumVertices(src_vtype),
hg->NumVertices(dst_vtype),
subrels[etype] = UnitGraph::Empty(
1, hg->NumVertices(src_vtype), hg->NumVertices(dst_vtype),
hg->DataType(), hg->Context());
induced_edges[etype] = aten::NullArray();
} else {
COOMatrix sampled_coo;
// sample from graph
// the edge type is stored in etypes
auto req_fmt = (dir == EdgeDir::kOut)? CSR_CODE : CSC_CODE;
auto req_fmt = (dir == EdgeDir::kOut) ? CSR_CODE : CSC_CODE;
auto avail_fmt = hg->SelectFormat(etype, req_fmt);
switch (avail_fmt) {
case SparseFormat::kCOO:
if (dir == EdgeDir::kIn) {
sampled_coo = aten::COOTranspose(aten::COORowWisePerEtypeSampling(
aten::COOTranspose(hg->GetCOOMatrix(etype)),
nodes, eid2etype_offset, fanouts, prob, replace));
aten::COOTranspose(hg->GetCOOMatrix(etype)), nodes,
eid2etype_offset, fanouts, prob, replace));
} else {
sampled_coo = aten::COORowWisePerEtypeSampling(
hg->GetCOOMatrix(etype), nodes, eid2etype_offset, fanouts, prob, replace);
hg->GetCOOMatrix(etype), nodes, eid2etype_offset, fanouts, prob,
replace);
}
break;
case SparseFormat::kCSR:
CHECK(dir == EdgeDir::kOut) << "Cannot sample out edges on CSC matrix.";
sampled_coo = aten::CSRRowWisePerEtypeSampling(
hg->GetCSRMatrix(etype), nodes, eid2etype_offset,
fanouts, prob, replace, rowwise_etype_sorted);
hg->GetCSRMatrix(etype), nodes, eid2etype_offset, fanouts, prob,
replace, rowwise_etype_sorted);
break;
case SparseFormat::kCSC:
CHECK(dir == EdgeDir::kIn) << "Cannot sample in edges on CSR matrix.";
sampled_coo = aten::CSRRowWisePerEtypeSampling(
hg->GetCSCMatrix(etype), nodes, eid2etype_offset,
fanouts, prob, replace, rowwise_etype_sorted);
hg->GetCSCMatrix(etype), nodes, eid2etype_offset, fanouts, prob,
replace, rowwise_etype_sorted);
sampled_coo = aten::COOTranspose(sampled_coo);
break;
default:
......@@ -218,25 +215,23 @@ HeteroSubgraph SampleNeighborsEType(
}
subrels[etype] = UnitGraph::CreateFromCOO(
1, sampled_coo.num_rows, sampled_coo.num_cols,
sampled_coo.row, sampled_coo.col);
1, sampled_coo.num_rows, sampled_coo.num_cols, sampled_coo.row,
sampled_coo.col);
induced_edges[etype] = sampled_coo.data;
}
HeteroSubgraph ret;
ret.graph = CreateHeteroGraph(hg->meta_graph(), subrels, hg->NumVerticesPerType());
ret.graph =
CreateHeteroGraph(hg->meta_graph(), subrels, hg->NumVerticesPerType());
ret.induced_vertices.resize(hg->NumVertexTypes());
ret.induced_edges = std::move(induced_edges);
return ret;
}
HeteroSubgraph SampleNeighborsTopk(
const HeteroGraphPtr hg,
const std::vector<IdArray>& nodes,
const std::vector<int64_t>& k,
EdgeDir dir,
const std::vector<FloatArray>& weight,
bool ascending) {
const HeteroGraphPtr hg, const std::vector<IdArray>& nodes,
const std::vector<int64_t>& k, EdgeDir dir,
const std::vector<FloatArray>& weight, bool ascending) {
// sanity check
CHECK_EQ(nodes.size(), hg->NumVertexTypes())
<< "Number of node ID tensors must match the number of node types.";
......@@ -251,77 +246,79 @@ HeteroSubgraph SampleNeighborsTopk(
auto pair = hg->meta_graph()->FindEdge(etype);
const dgl_type_t src_vtype = pair.first;
const dgl_type_t dst_vtype = pair.second;
const IdArray nodes_ntype = nodes[(dir == EdgeDir::kOut)? src_vtype : dst_vtype];
const IdArray nodes_ntype =
nodes[(dir == EdgeDir::kOut) ? src_vtype : dst_vtype];
const int64_t num_nodes = nodes_ntype->shape[0];
if (num_nodes == 0 || k[etype] == 0) {
// Nothing to sample for this etype, create a placeholder relation graph
subrels[etype] = UnitGraph::Empty(
hg->GetRelationGraph(etype)->NumVertexTypes(),
hg->NumVertices(src_vtype),
hg->NumVertices(dst_vtype),
hg->NumVertices(src_vtype), hg->NumVertices(dst_vtype),
hg->DataType(), hg->Context());
induced_edges[etype] = aten::NullArray();
} else {
// sample from one relation graph
auto req_fmt = (dir == EdgeDir::kOut)? CSR_CODE : CSC_CODE;
auto req_fmt = (dir == EdgeDir::kOut) ? CSR_CODE : CSC_CODE;
auto avail_fmt = hg->SelectFormat(etype, req_fmt);
COOMatrix sampled_coo;
switch (avail_fmt) {
case SparseFormat::kCOO:
if (dir == EdgeDir::kIn) {
sampled_coo = aten::COOTranspose(aten::COORowWiseTopk(
aten::COOTranspose(hg->GetCOOMatrix(etype)),
nodes_ntype, k[etype], weight[etype], ascending));
aten::COOTranspose(hg->GetCOOMatrix(etype)), nodes_ntype,
k[etype], weight[etype], ascending));
} else {
sampled_coo = aten::COORowWiseTopk(
hg->GetCOOMatrix(etype), nodes_ntype, k[etype], weight[etype], ascending);
hg->GetCOOMatrix(etype), nodes_ntype, k[etype], weight[etype],
ascending);
}
break;
case SparseFormat::kCSR:
CHECK(dir == EdgeDir::kOut) << "Cannot sample out edges on CSC matrix.";
CHECK(dir == EdgeDir::kOut)
<< "Cannot sample out edges on CSC matrix.";
sampled_coo = aten::CSRRowWiseTopk(
hg->GetCSRMatrix(etype), nodes_ntype, k[etype], weight[etype], ascending);
hg->GetCSRMatrix(etype), nodes_ntype, k[etype], weight[etype],
ascending);
break;
case SparseFormat::kCSC:
CHECK(dir == EdgeDir::kIn) << "Cannot sample in edges on CSR matrix.";
sampled_coo = aten::CSRRowWiseTopk(
hg->GetCSCMatrix(etype), nodes_ntype, k[etype], weight[etype], ascending);
hg->GetCSCMatrix(etype), nodes_ntype, k[etype], weight[etype],
ascending);
sampled_coo = aten::COOTranspose(sampled_coo);
break;
default:
LOG(FATAL) << "Unsupported sparse format.";
}
subrels[etype] = UnitGraph::CreateFromCOO(
hg->GetRelationGraph(etype)->NumVertexTypes(), sampled_coo.num_rows, sampled_coo.num_cols,
sampled_coo.row, sampled_coo.col);
hg->GetRelationGraph(etype)->NumVertexTypes(), sampled_coo.num_rows,
sampled_coo.num_cols, sampled_coo.row, sampled_coo.col);
induced_edges[etype] = sampled_coo.data;
}
}
HeteroSubgraph ret;
ret.graph = CreateHeteroGraph(hg->meta_graph(), subrels, hg->NumVerticesPerType());
ret.graph =
CreateHeteroGraph(hg->meta_graph(), subrels, hg->NumVerticesPerType());
ret.induced_vertices.resize(hg->NumVertexTypes());
ret.induced_edges = std::move(induced_edges);
return ret;
}
HeteroSubgraph SampleNeighborsBiased(
const HeteroGraphPtr hg,
const IdArray& nodes,
const int64_t fanout,
const NDArray& bias,
const NDArray& tag_offset,
const EdgeDir dir,
const bool replace
) {
CHECK_EQ(hg->NumEdgeTypes(), 1) << "Only homogeneous or bipartite graphs are supported";
const HeteroGraphPtr hg, const IdArray& nodes, const int64_t fanout,
const NDArray& bias, const NDArray& tag_offset, const EdgeDir dir,
const bool replace) {
CHECK_EQ(hg->NumEdgeTypes(), 1)
<< "Only homogeneous or bipartite graphs are supported";
auto pair = hg->meta_graph()->FindEdge(0);
const dgl_type_t src_vtype = pair.first;
const dgl_type_t dst_vtype = pair.second;
const dgl_type_t nodes_ntype = (dir == EdgeDir::kOut) ? src_vtype : dst_vtype;
// sanity check
CHECK_EQ(tag_offset->ndim, 2) << "The shape of tag_offset should be [num_nodes, num_tags + 1]";
CHECK_EQ(tag_offset->ndim, 2)
<< "The shape of tag_offset should be [num_nodes, num_tags + 1]";
CHECK_EQ(tag_offset->shape[0], hg->NumVertices(nodes_ntype))
<< "The shape of tag_offset should be [num_nodes, num_tags + 1]";
CHECK_EQ(tag_offset->shape[1], bias->shape[0] + 1)
......@@ -335,13 +332,12 @@ HeteroSubgraph SampleNeighborsBiased(
// Nothing to sample for this etype, create a placeholder relation graph
subrel = UnitGraph::Empty(
hg->GetRelationGraph(etype)->NumVertexTypes(),
hg->NumVertices(src_vtype),
hg->NumVertices(dst_vtype),
hg->DataType(), hg->Context());
hg->NumVertices(src_vtype), hg->NumVertices(dst_vtype), hg->DataType(),
hg->Context());
induced_edges = aten::NullArray();
} else {
// sample from one relation graph
const auto req_fmt = (dir == EdgeDir::kOut)? CSR_CODE : CSC_CODE;
const auto req_fmt = (dir == EdgeDir::kOut) ? CSR_CODE : CSC_CODE;
const auto created_fmt = hg->GetCreatedFormats();
COOMatrix sampled_coo;
......@@ -361,23 +357,25 @@ HeteroSubgraph SampleNeighborsBiased(
LOG(FATAL) << "Unsupported sparse format.";
}
subrel = UnitGraph::CreateFromCOO(
hg->GetRelationGraph(etype)->NumVertexTypes(), sampled_coo.num_rows, sampled_coo.num_cols,
sampled_coo.row, sampled_coo.col);
hg->GetRelationGraph(etype)->NumVertexTypes(), sampled_coo.num_rows,
sampled_coo.num_cols, sampled_coo.row, sampled_coo.col);
induced_edges = sampled_coo.data;
}
HeteroSubgraph ret;
ret.graph = CreateHeteroGraph(hg->meta_graph(), {subrel}, hg->NumVerticesPerType());
ret.graph =
CreateHeteroGraph(hg->meta_graph(), {subrel}, hg->NumVerticesPerType());
ret.induced_vertices.resize(hg->NumVertexTypes());
ret.induced_edges = {induced_edges};
return ret;
}
DGL_REGISTER_GLOBAL("sampling.neighbor._CAPI_DGLSampleNeighborsEType")
.set_body([] (DGLArgs args, DGLRetValue *rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
IdArray nodes = args[1];
const std::vector<int64_t>& eid2etype_offset = ListValueToVector<int64_t>(args[2]);
const std::vector<int64_t>& eid2etype_offset =
ListValueToVector<int64_t>(args[2]);
IdArray fanout = args[3];
const std::string dir_str = args[4];
const auto& prob = ListValueToVector<FloatArray>(args[5]);
......@@ -386,18 +384,19 @@ DGL_REGISTER_GLOBAL("sampling.neighbor._CAPI_DGLSampleNeighborsEType")
CHECK(dir_str == "in" || dir_str == "out")
<< "Invalid edge direction. Must be \"in\" or \"out\".";
EdgeDir dir = (dir_str == "in")? EdgeDir::kIn : EdgeDir::kOut;
EdgeDir dir = (dir_str == "in") ? EdgeDir::kIn : EdgeDir::kOut;
CHECK_INT64(fanout, "fanout");
std::vector<int64_t> fanout_vec = fanout.ToVector<int64_t>();
std::shared_ptr<HeteroSubgraph> subg(new HeteroSubgraph);
*subg = sampling::SampleNeighborsEType(
hg.sptr(), nodes, eid2etype_offset, fanout_vec, dir, prob, replace, rowwise_etype_sorted);
hg.sptr(), nodes, eid2etype_offset, fanout_vec, dir, prob, replace,
rowwise_etype_sorted);
*rv = HeteroSubgraphRef(subg);
});
DGL_REGISTER_GLOBAL("sampling.neighbor._CAPI_DGLSampleNeighbors")
.set_body([] (DGLArgs args, DGLRetValue *rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
const auto& nodes = ListValueToVector<IdArray>(args[1]);
IdArray fanouts_array = args[2];
......@@ -409,7 +408,7 @@ DGL_REGISTER_GLOBAL("sampling.neighbor._CAPI_DGLSampleNeighbors")
CHECK(dir_str == "in" || dir_str == "out")
<< "Invalid edge direction. Must be \"in\" or \"out\".";
EdgeDir dir = (dir_str == "in")? EdgeDir::kIn : EdgeDir::kOut;
EdgeDir dir = (dir_str == "in") ? EdgeDir::kIn : EdgeDir::kOut;
std::shared_ptr<HeteroSubgraph> subg(new HeteroSubgraph);
*subg = sampling::SampleNeighbors(
......@@ -419,7 +418,7 @@ DGL_REGISTER_GLOBAL("sampling.neighbor._CAPI_DGLSampleNeighbors")
});
DGL_REGISTER_GLOBAL("sampling.neighbor._CAPI_DGLSampleNeighborsTopk")
.set_body([] (DGLArgs args, DGLRetValue *rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
const auto& nodes = ListValueToVector<IdArray>(args[1]);
IdArray k_array = args[2];
......@@ -430,7 +429,7 @@ DGL_REGISTER_GLOBAL("sampling.neighbor._CAPI_DGLSampleNeighborsTopk")
CHECK(dir_str == "in" || dir_str == "out")
<< "Invalid edge direction. Must be \"in\" or \"out\".";
EdgeDir dir = (dir_str == "in")? EdgeDir::kIn : EdgeDir::kOut;
EdgeDir dir = (dir_str == "in") ? EdgeDir::kIn : EdgeDir::kOut;
std::shared_ptr<HeteroSubgraph> subg(new HeteroSubgraph);
*subg = sampling::SampleNeighborsTopk(
......@@ -440,7 +439,7 @@ DGL_REGISTER_GLOBAL("sampling.neighbor._CAPI_DGLSampleNeighborsTopk")
});
DGL_REGISTER_GLOBAL("sampling.neighbor._CAPI_DGLSampleNeighborsBiased")
.set_body([] (DGLArgs args, DGLRetValue *rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
const IdArray nodes = args[1];
const int64_t fanout = args[2];
......@@ -451,7 +450,7 @@ DGL_REGISTER_GLOBAL("sampling.neighbor._CAPI_DGLSampleNeighborsBiased")
CHECK(dir_str == "in" || dir_str == "out")
<< "Invalid edge direction. Must be \"in\" or \"out\".";
EdgeDir dir = (dir_str == "in")? EdgeDir::kIn : EdgeDir::kOut;
EdgeDir dir = (dir_str == "in") ? EdgeDir::kIn : EdgeDir::kOut;
std::shared_ptr<HeteroSubgraph> subg(new HeteroSubgraph);
*subg = sampling::SampleNeighborsBiased(
......
......@@ -65,7 +65,8 @@ DMLC_DECLARE_TRAITS(has_saveload, GraphDataObject, true);
namespace dgl {
namespace serialize {
bool SaveDGLGraphs(std::string filename, List<GraphData> graph_data,
bool SaveDGLGraphs(
std::string filename, List<GraphData> graph_data,
std::vector<NamedTensor> labels_list) {
auto fs = std::unique_ptr<SeekStream>(dynamic_cast<SeekStream *>(
SeekStream::Create(filename.c_str(), "w", true)));
......@@ -110,8 +111,9 @@ bool SaveDGLGraphs(std::string filename, List<GraphData> graph_data,
return true;
}
StorageMetaData LoadDGLGraphs(const std::string &filename,
std::vector<dgl_id_t> idx_list, bool onlyMeta) {
StorageMetaData LoadDGLGraphs(
const std::string &filename, std::vector<dgl_id_t> idx_list,
bool onlyMeta) {
auto fs = std::unique_ptr<SeekStream>(
SeekStream::CreateForRead(filename.c_str(), true));
CHECK(fs) << "Filename is invalid";
......@@ -179,8 +181,8 @@ StorageMetaData LoadDGLGraphs(const std::string &filename,
return metadata;
}
void GraphDataObject::SetData(ImmutableGraphPtr gptr,
Map<std::string, Value> node_tensors,
void GraphDataObject::SetData(
ImmutableGraphPtr gptr, Map<std::string, Value> node_tensors,
Map<std::string, Value> edge_tensors) {
this->gptr = gptr;
......@@ -243,21 +245,18 @@ ImmutableGraphPtr ToImmutableGraph(GraphPtr g) {
IdArray dsts_array = earray.dst;
bool row_sorted, col_sorted;
std::tie(row_sorted, col_sorted) = COOIsSorted(
aten::COOMatrix(mgr->NumVertices(), mgr->NumVertices(), srcs_array,
dsts_array));
std::tie(row_sorted, col_sorted) = COOIsSorted(aten::COOMatrix(
mgr->NumVertices(), mgr->NumVertices(), srcs_array, dsts_array));
ImmutableGraphPtr imgptr =
ImmutableGraph::CreateFromCOO(mgr->NumVertices(), srcs_array, dsts_array,
row_sorted, col_sorted);
ImmutableGraphPtr imgptr = ImmutableGraph::CreateFromCOO(
mgr->NumVertices(), srcs_array, dsts_array, row_sorted, col_sorted);
return imgptr;
}
}
void StorageMetaDataObject::SetMetaData(dgl_id_t num_graph,
std::vector<int64_t> nodes_num_list,
std::vector<int64_t> edges_num_list,
std::vector<NamedTensor> labels_list) {
void StorageMetaDataObject::SetMetaData(
dgl_id_t num_graph, std::vector<int64_t> nodes_num_list,
std::vector<int64_t> edges_num_list, std::vector<NamedTensor> labels_list) {
this->num_graph = num_graph;
this->nodes_num_list = Value(MakeValue(aten::VecToIdArray(nodes_num_list)));
this->edges_num_list = Value(MakeValue(aten::VecToIdArray(edges_num_list)));
......
......@@ -147,7 +147,8 @@ DGL_REGISTER_GLOBAL("data.graph_serialize._CAPI_LoadGraphFiles_V1")
// Packed-function entry point: converts the graph handle in args[0] into a
// heterograph handle. The handle must hold an ImmutableGraph.
DGL_REGISTER_GLOBAL("data.graph_serialize._CAPI_DGLAsHeteroGraph")
    .set_body([](DGLArgs args, DGLRetValue *rv) {
      GraphRef graph_ref = args[0];
      // The downcast yields a null pointer when the handle does not hold an
      // ImmutableGraph; the CHECK below rejects that case.
      ImmutableGraphPtr ig =
          std::dynamic_pointer_cast<ImmutableGraph>(graph_ref.sptr());
      CHECK(ig) << "graph is not readonly";
      *rv = HeteroGraphRef(ig->AsHeteroGraph());
    });
......
......@@ -48,8 +48,8 @@
#include <vector>
#include "../heterograph.h"
#include "./graph_serialize.h"
#include "./dglstream.h"
#include "./graph_serialize.h"
#include "dmlc/memory_io.h"
namespace dgl {
......@@ -61,9 +61,9 @@ using dmlc::Stream;
using dmlc::io::FileSystem;
using dmlc::io::URI;
bool SaveHeteroGraphs(std::string filename, List<HeteroGraphData> hdata,
const std::vector<NamedTensor> &nd_list,
dgl_format_code_t formats) {
bool SaveHeteroGraphs(
std::string filename, List<HeteroGraphData> hdata,
const std::vector<NamedTensor> &nd_list, dgl_format_code_t formats) {
auto fs = std::unique_ptr<DGLStream>(
DGLStream::Create(filename.c_str(), "w", false, formats));
CHECK(fs->IsValid()) << "File name " << filename << " is not a valid name";
......@@ -120,8 +120,8 @@ bool SaveHeteroGraphs(std::string filename, List<HeteroGraphData> hdata,
return true;
}
std::vector<HeteroGraphData> LoadHeteroGraphs(const std::string &filename,
std::vector<dgl_id_t> idx_list) {
std::vector<HeteroGraphData> LoadHeteroGraphs(
const std::string &filename, std::vector<dgl_id_t> idx_list) {
auto fs = std::unique_ptr<SeekStream>(
SeekStream::CreateForRead(filename.c_str(), false));
CHECK(fs) << "File name " << filename << " is not a valid name";
......@@ -213,8 +213,8 @@ DGL_REGISTER_GLOBAL("data.heterograph_serialize._CAPI_MakeHeteroGraphData")
List<Map<std::string, Value>> edata = args[2];
List<Value> ntype_names = args[3];
List<Value> etype_names = args[4];
*rv = HeteroGraphData::Create(hg.sptr(), ndata, edata, ntype_names,
etype_names);
*rv = HeteroGraphData::Create(
hg.sptr(), ndata, edata, ntype_names, etype_names);
});
DGL_REGISTER_GLOBAL("data.heterograph_serialize._CAPI_SaveHeteroGraphData")
......@@ -224,7 +224,7 @@ DGL_REGISTER_GLOBAL("data.heterograph_serialize._CAPI_SaveHeteroGraphData")
Map<std::string, Value> nd_map = args[2];
List<Value> formats = args[3];
std::vector<SparseFormat> formats_vec;
for (const auto& val : formats) {
for (const auto &val : formats) {
formats_vec.push_back(ParseSparseFormat(val->data));
}
const auto formats_code = SparseFormatsToCode(formats_vec);
......@@ -233,7 +233,8 @@ DGL_REGISTER_GLOBAL("data.heterograph_serialize._CAPI_SaveHeteroGraphData")
NDArray ndarray = static_cast<NDArray>(kv.second->data);
nd_list.emplace_back(kv.first, ndarray);
}
*rv = dgl::serialize::SaveHeteroGraphs(filename, hgdata, nd_list, formats_code);
*rv = dgl::serialize::SaveHeteroGraphs(
filename, hgdata, nd_list, formats_code);
});
DGL_REGISTER_GLOBAL(
......
......@@ -3,13 +3,14 @@
* @file graph/unit_graph.cc
* @brief UnitGraph graph implementation
*/
#include "./unit_graph.h"
#include <dgl/array.h>
#include <dgl/base_heterograph.h>
#include <dgl/immutable_graph.h>
#include <dgl/lazy.h>
#include "../c_api_common.h"
#include "./unit_graph.h"
#include "./serialize/dglstream.h"
namespace dgl {
......@@ -67,10 +68,10 @@ class UnitGraph::COO : public BaseHeteroGraph {
: BaseHeteroGraph(metagraph) {
CHECK(aten::IsValidIdArray(src));
CHECK(aten::IsValidIdArray(dst));
CHECK_EQ(src->shape[0], dst->shape[0]) << "Input arrays should have the same length.";
CHECK_EQ(src->shape[0], dst->shape[0])
<< "Input arrays should have the same length.";
adj_ = aten::COOMatrix{num_src, num_dst, src, dst,
NullArray(),
row_sorted, col_sorted};
NullArray(), row_sorted, col_sorted};
}
COO(GraphPtr metagraph, const aten::COOMatrix& coo)
......@@ -83,26 +84,19 @@ class UnitGraph::COO : public BaseHeteroGraph {
/**
 * @brief Construct an undefined COO placeholder.
 *
 * Both dimensions are set to -1 as a marker for "undefined"; a 0 x 0 matrix
 * cannot be used for that because it denotes a supported empty UnitGraph.
 */
COO() {
  adj_.num_cols = -1;
  adj_.num_rows = -1;
}
/** @brief True when both matrix dimensions are non-negative. */
bool defined() const {
  const bool rows_set = adj_.num_rows >= 0;
  const bool cols_set = adj_.num_cols >= 0;
  return rows_set && cols_set;
}
/** @brief Vertex type id used for sources; always 0. */
inline dgl_type_t SrcType() const {
  const dgl_type_t src_type = 0;
  return src_type;
}
/**
 * @brief Vertex type id used for destinations: 0 when there is a single
 * vertex type, 1 otherwise.
 */
inline dgl_type_t DstType() const {
  if (NumVertexTypes() == 1) {
    return 0;
  }
  return 1;
}
/** @brief Edge type id of the graph; always 0. */
inline dgl_type_t EdgeType() const {
  const dgl_type_t edge_type = 0;
  return edge_type;
}
HeteroGraphPtr GetRelationGraph(dgl_type_t etype) const override {
LOG(FATAL) << "The method shouldn't be called for UnitGraph graph. "
......@@ -122,67 +116,44 @@ class UnitGraph::COO : public BaseHeteroGraph {
LOG(FATAL) << "UnitGraph graph is not mutable.";
}
/** @brief Unsupported operation: always reports a fatal error. */
void Clear() override {
  LOG(FATAL) << "UnitGraph graph is not mutable.";
}
/** @brief Data type of the row index array. */
DGLDataType DataType() const override {
  const DGLDataType row_dtype = adj_.row->dtype;
  return row_dtype;
}
/** @brief Device context of the row index array. */
DGLContext Context() const override {
  const DGLContext row_ctx = adj_.row->ctx;
  return row_ctx;
}
/** @brief Reports the is_pinned flag of the underlying COO matrix. */
bool IsPinned() const override {
  const bool pinned = adj_.is_pinned;
  return pinned;
}
uint8_t NumBits() const override {
return adj_.row->dtype.bits;
}
uint8_t NumBits() const override { return adj_.row->dtype.bits; }
/**
 * @brief Return a COO graph whose index arrays use the requested bit width.
 * @param bits Target bit width.
 * @return A copy of this graph when the width already matches; otherwise a
 *         new COO built from the converted row and column arrays.
 */
COO AsNumBits(uint8_t bits) const {
  if (NumBits() != bits) {
    const IdArray converted_row = aten::AsNumBits(adj_.row, bits);
    const IdArray converted_col = aten::AsNumBits(adj_.col, bits);
    return COO(
        meta_graph_, adj_.num_rows, adj_.num_cols, converted_row,
        converted_col);
  }
  return *this;
}
/**
 * @brief Return a COO graph whose matrix lives on the given context.
 * @param ctx Target device context.
 * @return A copy of this graph when it is already on ctx; otherwise a new
 *         COO wrapping the matrix copied to ctx.
 */
COO CopyTo(const DGLContext& ctx) const {
  if (!(Context() == ctx)) {
    return COO(meta_graph_, adj_.CopyTo(ctx));
  }
  return *this;
}
/** @brief Pin the COO matrix (adj_) backing this graph. */
void PinMemory_() {
  adj_.PinMemory_();
}
/** @brief Unpin the COO matrix (adj_) backing this graph. */
void UnpinMemory_() {
  adj_.UnpinMemory_();
}
/**
 * @brief Forward a record-stream request to the COO matrix (adj_).
 * @param stream Stream handle passed through unchanged.
 */
void RecordStream(DGLStreamHandle stream) override { adj_.RecordStream(stream); }
/** @brief True when the COO matrix contains duplicate entries. */
bool IsMultigraph() const override {
  const bool has_duplicates = aten::COOHasDuplicate(adj_);
  return has_duplicates;
}
/** @brief This graph is never mutable, so this always returns true. */
bool IsReadonly() const override {
  return true;
}
uint64_t NumVertices(dgl_type_t vtype) const override {
if (vtype == SrcType()) {
......@@ -208,13 +179,15 @@ class UnitGraph::COO : public BaseHeteroGraph {
return {};
}
/**
 * @brief Test whether the COO matrix has a non-zero entry at (src, dst).
 * @param etype Edge type; not consulted by this implementation.
 * @param src Source vertex id, validated against the source vertex type.
 * @param dst Destination vertex id, validated against the destination type.
 */
bool HasEdgeBetween(
    dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const override {
  CHECK(HasVertex(SrcType(), src)) << "Invalid src vertex id: " << src;
  CHECK(HasVertex(DstType(), dst)) << "Invalid dst vertex id: " << dst;
  const bool connected = aten::COOIsNonZero(adj_, src, dst);
  return connected;
}
BoolArray HasEdgesBetween(dgl_type_t etype, IdArray src_ids, IdArray dst_ids) const override {
BoolArray HasEdgesBetween(
dgl_type_t etype, IdArray src_ids, IdArray dst_ids) const override {
CHECK(aten::IsValidIdArray(src_ids)) << "Invalid vertex id array.";
CHECK(aten::IsValidIdArray(dst_ids)) << "Invalid vertex id array.";
return aten::COOIsNonZero(adj_, src_ids, dst_ids);
......@@ -236,18 +209,21 @@ class UnitGraph::COO : public BaseHeteroGraph {
return aten::COOGetAllData(adj_, src, dst);
}
/**
 * @brief Look up data and indices for the given (src, dst) id arrays.
 * @param etype Edge type; not consulted by this implementation.
 * @param src Source vertex ids; must be a valid id array.
 * @param dst Destination vertex ids; must be a valid id array.
 * @return EdgeArray built from the first three arrays returned by
 *         aten::COOGetDataAndIndices, in that order.
 */
EdgeArray EdgeIdsAll(
    dgl_type_t etype, IdArray src, IdArray dst) const override {
  CHECK(aten::IsValidIdArray(src)) << "Invalid vertex id array.";
  CHECK(aten::IsValidIdArray(dst)) << "Invalid vertex id array.";
  const auto& lookup = aten::COOGetDataAndIndices(adj_, src, dst);
  return EdgeArray{lookup[0], lookup[1], lookup[2]};
}
IdArray EdgeIdsOne(dgl_type_t etype, IdArray src, IdArray dst) const override {
IdArray EdgeIdsOne(
dgl_type_t etype, IdArray src, IdArray dst) const override {
return aten::COOGetData(adj_, src, dst);
}
std::pair<dgl_id_t, dgl_id_t> FindEdge(dgl_type_t etype, dgl_id_t eid) const override {
std::pair<dgl_id_t, dgl_id_t> FindEdge(
dgl_type_t etype, dgl_id_t eid) const override {
CHECK(eid < NumEdges(etype)) << "Invalid edge id: " << eid;
const dgl_id_t src = aten::IndexSelect<int64_t>(adj_.row, eid);
const dgl_id_t dst = aten::IndexSelect<int64_t>(adj_.col, eid);
......@@ -256,18 +232,19 @@ class UnitGraph::COO : public BaseHeteroGraph {
/**
 * @brief Gather the endpoints of the given edge ids.
 * @param etype Edge type; not consulted by this implementation.
 * @param eids Edge ids; must be a valid id array.
 * @return EdgeArray of the selected rows, the selected columns, and eids.
 *
 * The matrix must not carry an explicit data array, so that eids can be
 * used directly as positions into row/col.
 */
EdgeArray FindEdges(dgl_type_t etype, IdArray eids) const override {
  CHECK(aten::IsValidIdArray(eids)) << "Invalid edge id array";
  BUG_IF_FAIL(aten::IsNullArray(adj_.data))
      << "FindEdges requires the internal COO matrix not having EIDs.";
  return EdgeArray{
      aten::IndexSelect(adj_.row, eids), aten::IndexSelect(adj_.col, eids),
      eids};
}
/**
 * @brief Collect the edges that end at vertex vid.
 * @param etype Edge type; not consulted by this implementation.
 * @param vid Destination vertex id.
 * @return EdgeArray of (sources, destinations, edge ids); the destination
 *         array is filled with vid.
 */
EdgeArray InEdges(dgl_type_t etype, dgl_id_t vid) const override {
  // Row vid of the transposed matrix holds the entries of column vid.
  IdArray in_eids, in_srcs;
  std::tie(in_eids, in_srcs) =
      aten::COOGetRowDataAndIndices(aten::COOTranspose(adj_), vid);
  IdArray in_dsts = aten::Full(vid, in_srcs->shape[0], NumBits(), in_srcs->ctx);
  return EdgeArray{in_srcs, in_dsts, in_eids};
}
......@@ -281,7 +258,8 @@ class UnitGraph::COO : public BaseHeteroGraph {
/**
 * @brief Collect the edges that start at vertex vid.
 * @param etype Edge type; not consulted by this implementation.
 * @param vid Source vertex id.
 * @return EdgeArray of (sources, destinations, edge ids); the source array
 *         is filled with vid.
 */
EdgeArray OutEdges(dgl_type_t etype, dgl_id_t vid) const override {
  IdArray out_eids, out_dsts;
  std::tie(out_eids, out_dsts) = aten::COOGetRowDataAndIndices(adj_, vid);
  IdArray out_srcs =
      aten::Full(vid, out_dsts->shape[0], NumBits(), out_dsts->ctx);
  return EdgeArray{out_srcs, out_dsts, out_eids};
}
......@@ -292,10 +270,11 @@ class UnitGraph::COO : public BaseHeteroGraph {
return EdgeArray{row, coosubmat.col, coosubmat.data};
}
/**
 * @brief Return all edges as (row, col, edge ids).
 * @param etype Edge type, used to obtain the edge count.
 * @param order Requested ordering; only "" and "eid" are accepted.
 * @return EdgeArray of the row array, the column array, and the id range
 *         [0, NumEdges(etype)).
 */
EdgeArray Edges(
    dgl_type_t etype, const std::string& order = "") const override {
  CHECK(order.empty() || order == std::string("eid"))
      << "COO only support Edges of order \"eid\", but got \"" << order
      << "\".";
  IdArray all_eids = aten::Range(0, NumEdges(etype), NumBits(), Context());
  return EdgeArray{adj_.row, adj_.col, all_eids};
}
......@@ -341,7 +320,7 @@ class UnitGraph::COO : public BaseHeteroGraph {
}
std::vector<IdArray> GetAdj(
dgl_type_t etype, bool transpose, const std::string &fmt) const override {
dgl_type_t etype, bool transpose, const std::string& fmt) const override {
CHECK(fmt == "coo") << "Not valid adj format request.";
if (transpose) {
return {aten::HStack(adj_.col, adj_.row)};
......@@ -350,9 +329,7 @@ class UnitGraph::COO : public BaseHeteroGraph {
}
}
/**
 * @brief Return the stored COO matrix.
 * @param etype Edge type; not consulted by this implementation.
 */
aten::COOMatrix GetCOOMatrix(dgl_type_t etype) const override {
  return adj_;
}
aten::CSRMatrix GetCSCMatrix(dgl_type_t etype) const override {
LOG(FATAL) << "Not enabled for COO graph";
......@@ -364,7 +341,8 @@ class UnitGraph::COO : public BaseHeteroGraph {
return aten::CSRMatrix();
}
/**
 * @brief Unsupported on the COO graph: reports a fatal error.
 * @return SparseFormat::kCOO, present only so the function has a return
 *         value after the fatal log statement.
 */
SparseFormat SelectFormat(
    dgl_type_t etype, dgl_format_code_t preferred_formats) const override {
  LOG(FATAL) << "Not enabled for COO graph";
  return SparseFormat::kCOO;
}
......@@ -379,8 +357,10 @@ class UnitGraph::COO : public BaseHeteroGraph {
return 0;
}
HeteroSubgraph VertexSubgraph(const std::vector<IdArray>& vids) const override {
CHECK_EQ(vids.size(), NumVertexTypes()) << "Number of vertex types mismatch";
HeteroSubgraph VertexSubgraph(
const std::vector<IdArray>& vids) const override {
CHECK_EQ(vids.size(), NumVertexTypes())
<< "Number of vertex types mismatch";
auto srcvids = vids[SrcType()], dstvids = vids[DstType()];
CHECK(aten::IsValidIdArray(srcvids)) << "Invalid vertex id array.";
CHECK(aten::IsValidIdArray(dstvids)) << "Invalid vertex id array.";
......@@ -388,15 +368,16 @@ class UnitGraph::COO : public BaseHeteroGraph {
const auto& submat = aten::COOSliceMatrix(adj_, srcvids, dstvids);
DGLContext ctx = aten::GetContextOf(vids);
IdArray sub_eids = aten::Range(0, submat.data->shape[0], NumBits(), ctx);
subg.graph = std::make_shared<COO>(meta_graph(), submat.num_rows, submat.num_cols,
submat.row, submat.col);
subg.graph = std::make_shared<COO>(
meta_graph(), submat.num_rows, submat.num_cols, submat.row, submat.col);
subg.induced_vertices = vids;
subg.induced_edges.emplace_back(submat.data);
return subg;
}
HeteroSubgraph EdgeSubgraph(
const std::vector<IdArray>& eids, bool preserve_nodes = false) const override {
const std::vector<IdArray>& eids,
bool preserve_nodes = false) const override {
CHECK_EQ(eids.size(), 1) << "Edge type number mismatch.";
HeteroSubgraph subg;
if (!preserve_nodes) {
......@@ -417,7 +398,8 @@ class UnitGraph::COO : public BaseHeteroGraph {
subg.induced_vertices.emplace_back(
aten::NullArray(DGLDataType{kDGLInt, NumBits(), 1}, Context()));
subg.graph = std::make_shared<COO>(
meta_graph(), NumVertices(SrcType()), NumVertices(DstType()), new_src, new_dst);
meta_graph(), NumVertices(SrcType()), NumVertices(DstType()), new_src,
new_dst);
subg.induced_edges = eids;
}
return subg;
......@@ -428,13 +410,11 @@ class UnitGraph::COO : public BaseHeteroGraph {
return nullptr;
}
/** @brief Return the stored COO matrix by value. */
aten::COOMatrix adj() const {
  return adj_;
}
/**
* @brief Determines whether the graph is "hypersparse", i.e. having significantly more
* nodes than edges.
* @brief Determines whether the graph is "hypersparse", i.e. having
* significantly more nodes than edges.
*/
bool IsHypersparse() const {
return (NumVertices(SrcType()) / 8 > NumEdges(EdgeType())) &&
......@@ -470,14 +450,17 @@ class UnitGraph::COO : public BaseHeteroGraph {
/** @brief CSR graph */
class UnitGraph::CSR : public BaseHeteroGraph {
public:
CSR(GraphPtr metagraph, int64_t num_src, int64_t num_dst,
IdArray indptr, IdArray indices, IdArray edge_ids)
CSR(GraphPtr metagraph, int64_t num_src, int64_t num_dst, IdArray indptr,
IdArray indices, IdArray edge_ids)
: BaseHeteroGraph(metagraph) {
CHECK(aten::IsValidIdArray(indptr));
CHECK(aten::IsValidIdArray(indices));
if (aten::IsValidIdArray(edge_ids))
CHECK((indices->shape[0] == edge_ids->shape[0]) || aten::IsNullArray(edge_ids))
<< "edge id arrays should have the same length as indices if not empty";
CHECK(
(indices->shape[0] == edge_ids->shape[0]) ||
aten::IsNullArray(edge_ids))
<< "edge id arrays should have the same length as indices if not "
"empty";
CHECK_EQ(num_src, indptr->shape[0] - 1)
<< "number of nodes do not match the length of indptr minus 1.";
......@@ -485,31 +468,23 @@ class UnitGraph::CSR : public BaseHeteroGraph {
}
/**
 * @brief Wrap an existing CSR matrix.
 * @param metagraph Metagraph handed to the BaseHeteroGraph base.
 * @param csr Matrix copied into adj_.
 */
CSR(GraphPtr metagraph, const aten::CSRMatrix& csr)
    : BaseHeteroGraph(metagraph), adj_(csr) {}
/**
 * @brief Construct an undefined CSR placeholder.
 *
 * Both dimensions are set to -1 as a marker for "undefined"; a 0 x 0 matrix
 * cannot be used for that because it denotes a supported empty UnitGraph.
 */
CSR() {
  adj_.num_cols = -1;
  adj_.num_rows = -1;
}
/**
 * @brief True when both matrix dimensions are non-negative.
 *
 * The default constructor marks an undefined graph by setting num_rows and
 * num_cols to -1. Both dimensions must be valid for the graph to count as
 * defined; this matches the COO implementation of defined(). The previous
 * `||` accepted a matrix with only one valid dimension.
 */
bool defined() const { return (adj_.num_rows >= 0) && (adj_.num_cols >= 0); }
/** @brief Vertex type id used for sources; always 0. */
inline dgl_type_t SrcType() const {
  const dgl_type_t src_type = 0;
  return src_type;
}
/**
 * @brief Vertex type id used for destinations: 0 when there is a single
 * vertex type, 1 otherwise.
 */
inline dgl_type_t DstType() const {
  if (NumVertexTypes() == 1) {
    return 0;
  }
  return 1;
}
/** @brief Edge type id of the graph; always 0. */
inline dgl_type_t EdgeType() const {
  const dgl_type_t edge_type = 0;
  return edge_type;
}
HeteroGraphPtr GetRelationGraph(dgl_type_t etype) const override {
LOG(FATAL) << "The method shouldn't be called for UnitGraph graph. "
......@@ -529,33 +504,22 @@ class UnitGraph::CSR : public BaseHeteroGraph {
LOG(FATAL) << "UnitGraph graph is not mutable.";
}
/** @brief Unsupported operation: always reports a fatal error. */
void Clear() override {
  LOG(FATAL) << "UnitGraph graph is not mutable.";
}
/** @brief Data type of the indices array. */
DGLDataType DataType() const override {
  const DGLDataType indices_dtype = adj_.indices->dtype;
  return indices_dtype;
}
/** @brief Device context of the indices array. */
DGLContext Context() const override {
  const DGLContext indices_ctx = adj_.indices->ctx;
  return indices_ctx;
}
/** @brief Reports the is_pinned flag of the underlying CSR matrix. */
bool IsPinned() const override {
  const bool pinned = adj_.is_pinned;
  return pinned;
}
uint8_t NumBits() const override {
return adj_.indices->dtype.bits;
}
uint8_t NumBits() const override { return adj_.indices->dtype.bits; }
CSR AsNumBits(uint8_t bits) const {
if (NumBits() == bits) {
return *this;
} else {
CSR ret(
meta_graph_,
adj_.num_rows, adj_.num_cols,
meta_graph_, adj_.num_rows, adj_.num_cols,
aten::AsNumBits(adj_.indptr, bits),
aten::AsNumBits(adj_.indices, bits),
aten::AsNumBits(adj_.data, bits));
......@@ -563,7 +527,7 @@ class UnitGraph::CSR : public BaseHeteroGraph {
}
}
CSR CopyTo(const DGLContext &ctx) const {
CSR CopyTo(const DGLContext& ctx) const {
if (Context() == ctx) {
return *this;
} else {
......@@ -572,27 +536,19 @@ class UnitGraph::CSR : public BaseHeteroGraph {
}
/** @brief Pin the CSR matrix (adj_) backing this graph. */
void PinMemory_() {
  adj_.PinMemory_();
}
/** @brief Unpin the CSR matrix (adj_) backing this graph. */
void UnpinMemory_() {
  adj_.UnpinMemory_();
}
/**
 * @brief Forward a record-stream request to the CSR matrix (adj_).
 * @param stream Stream handle passed through unchanged.
 */
void RecordStream(DGLStreamHandle stream) override { adj_.RecordStream(stream); }
/** @brief True when the CSR matrix contains duplicate entries. */
bool IsMultigraph() const override {
  const bool has_duplicates = aten::CSRHasDuplicate(adj_);
  return has_duplicates;
}
/** @brief This graph is never mutable, so this always returns true. */
bool IsReadonly() const override {
  return true;
}
uint64_t NumVertices(dgl_type_t vtype) const override {
if (vtype == SrcType()) {
......@@ -618,13 +574,15 @@ class UnitGraph::CSR : public BaseHeteroGraph {
return {};
}
/**
 * @brief Test whether the CSR matrix has a non-zero entry at (src, dst).
 * @param etype Edge type; not consulted by this implementation.
 * @param src Source vertex id, validated against the source vertex type.
 * @param dst Destination vertex id, validated against the destination type.
 */
bool HasEdgeBetween(
    dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const override {
  CHECK(HasVertex(SrcType(), src)) << "Invalid src vertex id: " << src;
  CHECK(HasVertex(DstType(), dst)) << "Invalid dst vertex id: " << dst;
  const bool connected = aten::CSRIsNonZero(adj_, src, dst);
  return connected;
}
BoolArray HasEdgesBetween(dgl_type_t etype, IdArray src_ids, IdArray dst_ids) const override {
BoolArray HasEdgesBetween(
dgl_type_t etype, IdArray src_ids, IdArray dst_ids) const override {
CHECK(aten::IsValidIdArray(src_ids)) << "Invalid vertex id array.";
CHECK(aten::IsValidIdArray(dst_ids)) << "Invalid vertex id array.";
return aten::CSRIsNonZero(adj_, src_ids, dst_ids);
......@@ -646,18 +604,21 @@ class UnitGraph::CSR : public BaseHeteroGraph {
return aten::CSRGetAllData(adj_, src, dst);
}
/**
 * @brief Look up data and indices for the given (src, dst) id arrays.
 * @param etype Edge type; not consulted by this implementation.
 * @param src Source vertex ids; must be a valid id array.
 * @param dst Destination vertex ids; must be a valid id array.
 * @return EdgeArray built from the first three arrays returned by
 *         aten::CSRGetDataAndIndices, in that order.
 */
EdgeArray EdgeIdsAll(
    dgl_type_t etype, IdArray src, IdArray dst) const override {
  CHECK(aten::IsValidIdArray(src)) << "Invalid vertex id array.";
  CHECK(aten::IsValidIdArray(dst)) << "Invalid vertex id array.";
  const auto& lookup = aten::CSRGetDataAndIndices(adj_, src, dst);
  return EdgeArray{lookup[0], lookup[1], lookup[2]};
}
/**
 * @brief Return the result of aten::CSRGetData for the (src, dst) id arrays.
 * @param etype Edge type; not consulted by this implementation.
 */
IdArray EdgeIdsOne(
    dgl_type_t etype, IdArray src, IdArray dst) const override {
  IdArray found = aten::CSRGetData(adj_, src, dst);
  return found;
}
/**
 * @brief Unsupported on the CSR graph: reports a fatal error.
 * @return A value-initialized pair, present only so the function has a
 *         return value after the fatal log statement.
 */
std::pair<dgl_id_t, dgl_id_t> FindEdge(
    dgl_type_t etype, dgl_id_t eid) const override {
  LOG(FATAL) << "Not enabled for CSR graph.";
  return {};
}
......@@ -681,7 +642,8 @@ class UnitGraph::CSR : public BaseHeteroGraph {
CHECK(HasVertex(SrcType(), vid)) << "Invalid src vertex id: " << vid;
IdArray ret_dst = aten::CSRGetRowColumnIndices(adj_, vid);
IdArray ret_eid = aten::CSRGetRowData(adj_, vid);
IdArray ret_src = aten::Full(vid, ret_dst->shape[0], NumBits(), ret_dst->ctx);
IdArray ret_src =
aten::Full(vid, ret_dst->shape[0], NumBits(), ret_dst->ctx);
return EdgeArray{ret_src, ret_dst, ret_eid};
}
......@@ -695,7 +657,8 @@ class UnitGraph::CSR : public BaseHeteroGraph {
return EdgeArray{row, coosubmat.col, coosubmat.data};
}
EdgeArray Edges(dgl_type_t etype, const std::string &order = "") const override {
EdgeArray Edges(
dgl_type_t etype, const std::string& order = "") const override {
CHECK(order.empty() || order == std::string("srcdst"))
<< "CSR only support Edges of order \"srcdst\","
<< " but got \"" << order << "\".";
......@@ -770,7 +733,7 @@ class UnitGraph::CSR : public BaseHeteroGraph {
}
/**
 * @brief Return the three CSR arrays: indptr, indices, data.
 * @param etype Edge type; not consulted by this implementation.
 * @param transpose Must be false.
 * @param fmt Must be "csr".
 */
std::vector<IdArray> GetAdj(
    dgl_type_t etype, bool transpose, const std::string& fmt) const override {
  CHECK(!transpose && fmt == "csr") << "Not valid adj format request.";
  return {adj_.indptr, adj_.indices, adj_.data};
}
......@@ -785,11 +748,10 @@ class UnitGraph::CSR : public BaseHeteroGraph {
return aten::CSRMatrix();
}
/**
 * @brief Return the stored CSR matrix.
 * @param etype Edge type; not consulted by this implementation.
 */
aten::CSRMatrix GetCSRMatrix(dgl_type_t etype) const override {
  return adj_;
}
/**
 * @brief Unsupported on the CSR graph: reports a fatal error.
 * @return SparseFormat::kCSR, present only so the function has a return
 *         value after the fatal log statement.
 */
SparseFormat SelectFormat(
    dgl_type_t etype, dgl_format_code_t preferred_formats) const override {
  LOG(FATAL) << "Not enabled for CSR graph";
  return SparseFormat::kCSR;
}
......@@ -804,8 +766,10 @@ class UnitGraph::CSR : public BaseHeteroGraph {
return 0;
}
HeteroSubgraph VertexSubgraph(const std::vector<IdArray>& vids) const override {
CHECK_EQ(vids.size(), NumVertexTypes()) << "Number of vertex types mismatch";
HeteroSubgraph VertexSubgraph(
const std::vector<IdArray>& vids) const override {
CHECK_EQ(vids.size(), NumVertexTypes())
<< "Number of vertex types mismatch";
auto srcvids = vids[SrcType()], dstvids = vids[DstType()];
CHECK(aten::IsValidIdArray(srcvids)) << "Invalid vertex id array.";
CHECK(aten::IsValidIdArray(dstvids)) << "Invalid vertex id array.";
......@@ -813,15 +777,17 @@ class UnitGraph::CSR : public BaseHeteroGraph {
const auto& submat = aten::CSRSliceMatrix(adj_, srcvids, dstvids);
DGLContext ctx = aten::GetContextOf(vids);
IdArray sub_eids = aten::Range(0, submat.data->shape[0], NumBits(), ctx);
subg.graph = std::make_shared<CSR>(meta_graph(), submat.num_rows, submat.num_cols,
submat.indptr, submat.indices, sub_eids);
subg.graph = std::make_shared<CSR>(
meta_graph(), submat.num_rows, submat.num_cols, submat.indptr,
submat.indices, sub_eids);
subg.induced_vertices = vids;
subg.induced_edges.emplace_back(submat.data);
return subg;
}
/**
 * @brief Unsupported on the CSR graph: reports a fatal error.
 * @return A value-initialized HeteroSubgraph, present only so the function
 *         has a return value after the fatal log statement.
 */
HeteroSubgraph EdgeSubgraph(
    const std::vector<IdArray>& eids,
    bool preserve_nodes = false) const override {
  LOG(FATAL) << "Not enabled for CSR graph.";
  return {};
}
......@@ -831,9 +797,7 @@ class UnitGraph::CSR : public BaseHeteroGraph {
return nullptr;
}
/** @brief Return the stored CSR matrix by value. */
aten::CSRMatrix adj() const {
  return adj_;
}
bool Load(dmlc::Stream* fs) {
auto meta_imgraph = Serializer::make_shared<ImmutableGraph>();
......@@ -861,21 +825,13 @@ class UnitGraph::CSR : public BaseHeteroGraph {
//
//////////////////////////////////////////////////////////
/** @brief Data type reported by the graph returned from GetAny(). */
DGLDataType UnitGraph::DataType() const {
  const auto any_format = GetAny();
  return any_format->DataType();
}
/** @brief Device context reported by the graph returned from GetAny(). */
DGLContext UnitGraph::Context() const {
  const auto any_format = GetAny();
  return any_format->Context();
}
/** @brief Pinned state reported by the graph returned from GetAny(). */
bool UnitGraph::IsPinned() const {
  const auto any_format = GetAny();
  return any_format->IsPinned();
}
/** @brief Index bit width reported by the graph returned from GetAny(). */
uint8_t UnitGraph::NumBits() const {
  const auto any_format = GetAny();
  return any_format->NumBits();
}
bool UnitGraph::IsMultigraph() const {
const SparseFormat fmt = SelectFormat(CSC_CODE);
......@@ -910,7 +866,8 @@ BoolArray UnitGraph::HasVertices(dgl_type_t vtype, IdArray vids) const {
return aten::LT(vids, NumVertices(vtype));
}
bool UnitGraph::HasEdgeBetween(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const {
bool UnitGraph::HasEdgeBetween(
dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const {
const SparseFormat fmt = SelectFormat(CSC_CODE);
const auto ptr = GetFormat(fmt);
if (fmt == SparseFormat::kCSC)
......@@ -953,7 +910,8 @@ IdArray UnitGraph::EdgeId(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const {
return ptr->EdgeId(etype, src, dst);
}
EdgeArray UnitGraph::EdgeIdsAll(dgl_type_t etype, IdArray src, IdArray dst) const {
EdgeArray UnitGraph::EdgeIdsAll(
dgl_type_t etype, IdArray src, IdArray dst) const {
const SparseFormat fmt = SelectFormat(CSR_CODE);
const auto ptr = GetFormat(fmt);
if (fmt == SparseFormat::kCSC) {
......@@ -964,7 +922,8 @@ EdgeArray UnitGraph::EdgeIdsAll(dgl_type_t etype, IdArray src, IdArray dst) cons
}
}
IdArray UnitGraph::EdgeIdsOne(dgl_type_t etype, IdArray src, IdArray dst) const {
IdArray UnitGraph::EdgeIdsOne(
dgl_type_t etype, IdArray src, IdArray dst) const {
const SparseFormat fmt = SelectFormat(CSR_CODE);
const auto ptr = GetFormat(fmt);
if (fmt == SparseFormat::kCSC) {
......@@ -974,7 +933,8 @@ IdArray UnitGraph::EdgeIdsOne(dgl_type_t etype, IdArray src, IdArray dst) const
}
}
std::pair<dgl_id_t, dgl_id_t> UnitGraph::FindEdge(dgl_type_t etype, dgl_id_t eid) const {
std::pair<dgl_id_t, dgl_id_t> UnitGraph::FindEdge(
dgl_type_t etype, dgl_id_t eid) const {
const SparseFormat fmt = SelectFormat(COO_CODE);
const auto ptr = GetFormat(fmt);
return ptr->FindEdge(etype, eid);
......@@ -1020,7 +980,7 @@ EdgeArray UnitGraph::OutEdges(dgl_type_t etype, IdArray vids) const {
return ptr->OutEdges(etype, vids);
}
EdgeArray UnitGraph::Edges(dgl_type_t etype, const std::string &order) const {
EdgeArray UnitGraph::Edges(dgl_type_t etype, const std::string& order) const {
SparseFormat fmt;
if (order == std::string("eid")) {
fmt = SelectFormat(COO_CODE);
......@@ -1117,14 +1077,15 @@ DGLIdIters UnitGraph::InEdgeVec(dgl_type_t etype, dgl_id_t vid) const {
}
std::vector<IdArray> UnitGraph::GetAdj(
dgl_type_t etype, bool transpose, const std::string &fmt) const {
// TODO(minjie): Our current semantics of adjacency matrix is row for dst nodes and col for
// src nodes. Therefore, we need to flip the transpose flag. For example, transpose=False
// is equal to in edge CSR.
// We have this behavior because previously we use framework's SPMM and we don't cache
// reverse adj. This is not intuitive and also not consistent with networkx's
// to_scipy_sparse_matrix. With the upcoming custom kernel change, we should change the
// behavior and make row for src and col for dst.
dgl_type_t etype, bool transpose, const std::string& fmt) const {
// TODO(minjie): Our current semantics of adjacency matrix is row for dst
// nodes and col for src nodes. Therefore, we need to flip the transpose flag.
// For example, transpose=False is equal to in edge CSR. We have this behavior
// because previously we use framework's SPMM and we don't cache reverse adj.
// This is not intuitive and also not consistent with networkx's
// to_scipy_sparse_matrix. With the upcoming custom kernel change, we should
// change the behavior and make row for src and col for dst.
if (fmt == std::string("csr")) {
return !transpose ? GetOutCSR()->GetAdj(etype, false, "csr")
: GetInCSR()->GetAdj(etype, false, "csr");
......@@ -1136,7 +1097,8 @@ std::vector<IdArray> UnitGraph::GetAdj(
}
}
HeteroSubgraph UnitGraph::VertexSubgraph(const std::vector<IdArray>& vids) const {
HeteroSubgraph UnitGraph::VertexSubgraph(
const std::vector<IdArray>& vids) const {
// We prefer to generate a subgraph from out-csr.
SparseFormat fmt = SelectFormat(CSR_CODE);
HeteroSubgraph sg = GetFormat(fmt)->VertexSubgraph(vids);
......@@ -1160,7 +1122,8 @@ HeteroSubgraph UnitGraph::VertexSubgraph(const std::vector<IdArray>& vids) const
return ret;
}
ret.graph = HeteroGraphPtr(new UnitGraph(meta_graph(), subcsc, subcsr, subcoo));
ret.graph =
HeteroGraphPtr(new UnitGraph(meta_graph(), subcsc, subcsr, subcoo));
ret.induced_vertices = std::move(sg.induced_vertices);
ret.induced_edges = std::move(sg.induced_edges);
return ret;
......@@ -1190,80 +1153,67 @@ HeteroSubgraph UnitGraph::EdgeSubgraph(
return ret;
}
ret.graph = HeteroGraphPtr(new UnitGraph(meta_graph(), subcsc, subcsr, subcoo));
ret.graph =
HeteroGraphPtr(new UnitGraph(meta_graph(), subcsc, subcsr, subcoo));
ret.induced_vertices = std::move(sg.induced_vertices);
ret.induced_edges = std::move(sg.induced_edges);
return ret;
}
HeteroGraphPtr UnitGraph::CreateFromCOO(
int64_t num_vtypes, int64_t num_src, int64_t num_dst,
IdArray row, IdArray col,
bool row_sorted, bool col_sorted,
dgl_format_code_t formats) {
int64_t num_vtypes, int64_t num_src, int64_t num_dst, IdArray row,
IdArray col, bool row_sorted, bool col_sorted, dgl_format_code_t formats) {
CHECK(num_vtypes == 1 || num_vtypes == 2);
if (num_vtypes == 1)
CHECK_EQ(num_src, num_dst);
if (num_vtypes == 1) CHECK_EQ(num_src, num_dst);
auto mg = CreateUnitGraphMetaGraph(num_vtypes);
COOPtr coo(new COO(mg, num_src, num_dst, row, col,
row_sorted, col_sorted));
COOPtr coo(new COO(mg, num_src, num_dst, row, col, row_sorted, col_sorted));
return HeteroGraphPtr(
new UnitGraph(mg, nullptr, nullptr, coo, formats));
return HeteroGraphPtr(new UnitGraph(mg, nullptr, nullptr, coo, formats));
}
HeteroGraphPtr UnitGraph::CreateFromCOO(
int64_t num_vtypes, const aten::COOMatrix& mat,
dgl_format_code_t formats) {
int64_t num_vtypes, const aten::COOMatrix& mat, dgl_format_code_t formats) {
CHECK(num_vtypes == 1 || num_vtypes == 2);
if (num_vtypes == 1)
CHECK_EQ(mat.num_rows, mat.num_cols);
if (num_vtypes == 1) CHECK_EQ(mat.num_rows, mat.num_cols);
auto mg = CreateUnitGraphMetaGraph(num_vtypes);
COOPtr coo(new COO(mg, mat));
return HeteroGraphPtr(
new UnitGraph(mg, nullptr, nullptr, coo, formats));
return HeteroGraphPtr(new UnitGraph(mg, nullptr, nullptr, coo, formats));
}
HeteroGraphPtr UnitGraph::CreateFromCSR(
int64_t num_vtypes, int64_t num_src, int64_t num_dst,
IdArray indptr, IdArray indices, IdArray edge_ids, dgl_format_code_t formats) {
int64_t num_vtypes, int64_t num_src, int64_t num_dst, IdArray indptr,
IdArray indices, IdArray edge_ids, dgl_format_code_t formats) {
CHECK(num_vtypes == 1 || num_vtypes == 2);
if (num_vtypes == 1)
CHECK_EQ(num_src, num_dst);
if (num_vtypes == 1) CHECK_EQ(num_src, num_dst);
auto mg = CreateUnitGraphMetaGraph(num_vtypes);
CSRPtr csr(new CSR(mg, num_src, num_dst, indptr, indices, edge_ids));
return HeteroGraphPtr(new UnitGraph(mg, nullptr, csr, nullptr, formats));
}
HeteroGraphPtr UnitGraph::CreateFromCSR(
int64_t num_vtypes, const aten::CSRMatrix& mat,
dgl_format_code_t formats) {
int64_t num_vtypes, const aten::CSRMatrix& mat, dgl_format_code_t formats) {
CHECK(num_vtypes == 1 || num_vtypes == 2);
if (num_vtypes == 1)
CHECK_EQ(mat.num_rows, mat.num_cols);
if (num_vtypes == 1) CHECK_EQ(mat.num_rows, mat.num_cols);
auto mg = CreateUnitGraphMetaGraph(num_vtypes);
CSRPtr csr(new CSR(mg, mat));
return HeteroGraphPtr(new UnitGraph(mg, nullptr, csr, nullptr, formats));
}
HeteroGraphPtr UnitGraph::CreateFromCSC(
int64_t num_vtypes, int64_t num_src, int64_t num_dst,
IdArray indptr, IdArray indices, IdArray edge_ids, dgl_format_code_t formats) {
int64_t num_vtypes, int64_t num_src, int64_t num_dst, IdArray indptr,
IdArray indices, IdArray edge_ids, dgl_format_code_t formats) {
CHECK(num_vtypes == 1 || num_vtypes == 2);
if (num_vtypes == 1)
CHECK_EQ(num_src, num_dst);
if (num_vtypes == 1) CHECK_EQ(num_src, num_dst);
auto mg = CreateUnitGraphMetaGraph(num_vtypes);
CSRPtr csc(new CSR(mg, num_dst, num_src, indptr, indices, edge_ids));
return HeteroGraphPtr(new UnitGraph(mg, csc, nullptr, nullptr, formats));
}
HeteroGraphPtr UnitGraph::CreateFromCSC(
int64_t num_vtypes, const aten::CSRMatrix& mat,
dgl_format_code_t formats) {
int64_t num_vtypes, const aten::CSRMatrix& mat, dgl_format_code_t formats) {
CHECK(num_vtypes == 1 || num_vtypes == 2);
if (num_vtypes == 1)
CHECK_EQ(mat.num_rows, mat.num_cols);
if (num_vtypes == 1) CHECK_EQ(mat.num_rows, mat.num_cols);
auto mg = CreateUnitGraphMetaGraph(num_vtypes);
CSRPtr csc(new CSR(mg, mat));
return HeteroGraphPtr(new UnitGraph(mg, csc, nullptr, nullptr, formats));
......@@ -1275,18 +1225,21 @@ HeteroGraphPtr UnitGraph::AsNumBits(HeteroGraphPtr g, uint8_t bits) {
} else {
auto bg = std::dynamic_pointer_cast<UnitGraph>(g);
CHECK_NOTNULL(bg);
CSRPtr new_incsr =
(bg->in_csr_->defined())? CSRPtr(new CSR(bg->in_csr_->AsNumBits(bits))) : nullptr;
CSRPtr new_outcsr =
(bg->out_csr_->defined())? CSRPtr(new CSR(bg->out_csr_->AsNumBits(bits))) : nullptr;
COOPtr new_coo =
(bg->coo_->defined())? COOPtr(new COO(bg->coo_->AsNumBits(bits))) : nullptr;
return HeteroGraphPtr(
new UnitGraph(g->meta_graph(), new_incsr, new_outcsr, new_coo, bg->formats_));
CSRPtr new_incsr = (bg->in_csr_->defined())
? CSRPtr(new CSR(bg->in_csr_->AsNumBits(bits)))
: nullptr;
CSRPtr new_outcsr = (bg->out_csr_->defined())
? CSRPtr(new CSR(bg->out_csr_->AsNumBits(bits)))
: nullptr;
COOPtr new_coo = (bg->coo_->defined())
? COOPtr(new COO(bg->coo_->AsNumBits(bits)))
: nullptr;
return HeteroGraphPtr(new UnitGraph(
g->meta_graph(), new_incsr, new_outcsr, new_coo, bg->formats_));
}
}
HeteroGraphPtr UnitGraph::CopyTo(HeteroGraphPtr g, const DGLContext &ctx) {
HeteroGraphPtr UnitGraph::CopyTo(HeteroGraphPtr g, const DGLContext& ctx) {
if (ctx == g->Context()) {
return g;
} else {
......@@ -1301,54 +1254,43 @@ HeteroGraphPtr UnitGraph::CopyTo(HeteroGraphPtr g, const DGLContext &ctx) {
COOPtr new_coo = (bg->coo_->defined())
? COOPtr(new COO(bg->coo_->CopyTo(ctx)))
: nullptr;
return HeteroGraphPtr(
new UnitGraph(g->meta_graph(), new_incsr, new_outcsr, new_coo, bg->formats_));
return HeteroGraphPtr(new UnitGraph(
g->meta_graph(), new_incsr, new_outcsr, new_coo, bg->formats_));
}
}
void UnitGraph::PinMemory_() {
if (this->in_csr_->defined())
this->in_csr_->PinMemory_();
if (this->out_csr_->defined())
this->out_csr_->PinMemory_();
if (this->coo_->defined())
this->coo_->PinMemory_();
if (this->in_csr_->defined()) this->in_csr_->PinMemory_();
if (this->out_csr_->defined()) this->out_csr_->PinMemory_();
if (this->coo_->defined()) this->coo_->PinMemory_();
}
void UnitGraph::UnpinMemory_() {
if (this->in_csr_->defined())
this->in_csr_->UnpinMemory_();
if (this->out_csr_->defined())
this->out_csr_->UnpinMemory_();
if (this->coo_->defined())
this->coo_->UnpinMemory_();
if (this->in_csr_->defined()) this->in_csr_->UnpinMemory_();
if (this->out_csr_->defined()) this->out_csr_->UnpinMemory_();
if (this->coo_->defined()) this->coo_->UnpinMemory_();
}
void UnitGraph::RecordStream(DGLStreamHandle stream) {
if (this->in_csr_->defined())
this->in_csr_->RecordStream(stream);
if (this->out_csr_->defined())
this->out_csr_->RecordStream(stream);
if (this->coo_->defined())
this->coo_->RecordStream(stream);
if (this->in_csr_->defined()) this->in_csr_->RecordStream(stream);
if (this->out_csr_->defined()) this->out_csr_->RecordStream(stream);
if (this->coo_->defined()) this->coo_->RecordStream(stream);
this->recorded_streams.push_back(stream);
}
void UnitGraph::InvalidateCSR() {
this->out_csr_ = CSRPtr(new CSR());
}
void UnitGraph::InvalidateCSR() { this->out_csr_ = CSRPtr(new CSR()); }
void UnitGraph::InvalidateCSC() {
this->in_csr_ = CSRPtr(new CSR());
}
void UnitGraph::InvalidateCSC() { this->in_csr_ = CSRPtr(new CSR()); }
void UnitGraph::InvalidateCOO() {
this->coo_ = COOPtr(new COO());
}
void UnitGraph::InvalidateCOO() { this->coo_ = COOPtr(new COO()); }
UnitGraph::UnitGraph(GraphPtr metagraph, CSRPtr in_csr, CSRPtr out_csr, COOPtr coo,
UnitGraph::UnitGraph(
GraphPtr metagraph, CSRPtr in_csr, CSRPtr out_csr, COOPtr coo,
dgl_format_code_t formats)
: BaseHeteroGraph(metagraph), in_csr_(in_csr), out_csr_(out_csr), coo_(coo) {
: BaseHeteroGraph(metagraph),
in_csr_(in_csr),
out_csr_(out_csr),
coo_(coo) {
if (!in_csr_) {
in_csr_ = CSRPtr(new CSR());
}
......@@ -1361,20 +1303,16 @@ UnitGraph::UnitGraph(GraphPtr metagraph, CSRPtr in_csr, CSRPtr out_csr, COOPtr c
formats_ = formats;
dgl_format_code_t created = GetCreatedFormats();
if ((formats | created) != formats)
LOG(FATAL) << "Graph created from formats: " << CodeToStr(created) <<
", which is not compatible with available formats: " << CodeToStr(formats);
LOG(FATAL) << "Graph created from formats: " << CodeToStr(created)
<< ", which is not compatible with available formats: "
<< CodeToStr(formats);
CHECK(GetAny()) << "At least one graph structure should exist.";
}
HeteroGraphPtr UnitGraph::CreateUnitGraphFrom(
int num_vtypes,
const aten::CSRMatrix &in_csr,
const aten::CSRMatrix &out_csr,
const aten::COOMatrix &coo,
bool has_in_csr,
bool has_out_csr,
bool has_coo,
dgl_format_code_t formats) {
int num_vtypes, const aten::CSRMatrix& in_csr,
const aten::CSRMatrix& out_csr, const aten::COOMatrix& coo, bool has_in_csr,
bool has_out_csr, bool has_coo, dgl_format_code_t formats) {
auto mg = CreateUnitGraphMetaGraph(num_vtypes);
CSRPtr in_csr_ptr = nullptr;
......@@ -1394,21 +1332,21 @@ HeteroGraphPtr UnitGraph::CreateUnitGraphFrom(
else
coo_ptr = COOPtr(new COO());
return HeteroGraphPtr(new UnitGraph(mg, in_csr_ptr, out_csr_ptr, coo_ptr, formats));
return HeteroGraphPtr(
new UnitGraph(mg, in_csr_ptr, out_csr_ptr, coo_ptr, formats));
}
UnitGraph::CSRPtr UnitGraph::GetInCSR(bool inplace) const {
if (inplace)
if (!(formats_ & CSC_CODE))
LOG(FATAL) << "The graph have restricted sparse format " <<
CodeToStr(formats_) << ", cannot create CSC matrix.";
LOG(FATAL) << "The graph have restricted sparse format "
<< CodeToStr(formats_) << ", cannot create CSC matrix.";
CSRPtr ret = in_csr_;
// Prefers converting from COO since it is parallelized.
// TODO(BarclayII): need benchmarking.
if (!in_csr_->defined()) {
if (coo_->defined()) {
const auto& newadj = aten::COOToCSR(
aten::COOTranspose(coo_->adj()));
const auto& newadj = aten::COOToCSR(aten::COOTranspose(coo_->adj()));
if (inplace)
*(const_cast<UnitGraph*>(this)->in_csr_) = CSR(meta_graph(), newadj);
......@@ -1424,10 +1362,8 @@ UnitGraph::CSRPtr UnitGraph::GetInCSR(bool inplace) const {
ret = std::make_shared<CSR>(meta_graph(), newadj);
}
if (inplace) {
if (IsPinned())
in_csr_->PinMemory_();
for (auto stream : recorded_streams)
in_csr_->RecordStream(stream);
if (IsPinned()) in_csr_->PinMemory_();
for (auto stream : recorded_streams) in_csr_->RecordStream(stream);
}
}
return ret;
......@@ -1437,8 +1373,8 @@ UnitGraph::CSRPtr UnitGraph::GetInCSR(bool inplace) const {
UnitGraph::CSRPtr UnitGraph::GetOutCSR(bool inplace) const {
if (inplace)
if (!(formats_ & CSR_CODE))
LOG(FATAL) << "The graph have restricted sparse format " <<
CodeToStr(formats_) << ", cannot create CSR matrix.";
LOG(FATAL) << "The graph have restricted sparse format "
<< CodeToStr(formats_) << ", cannot create CSR matrix.";
CSRPtr ret = out_csr_;
// Prefers converting from COO since it is parallelized.
// TODO(BarclayII): need benchmarking.
......@@ -1460,10 +1396,8 @@ UnitGraph::CSRPtr UnitGraph::GetOutCSR(bool inplace) const {
ret = std::make_shared<CSR>(meta_graph(), newadj);
}
if (inplace) {
if (IsPinned())
out_csr_->PinMemory_();
for (auto stream : recorded_streams)
out_csr_->RecordStream(stream);
if (IsPinned()) out_csr_->PinMemory_();
for (auto stream : recorded_streams) out_csr_->RecordStream(stream);
}
}
return ret;
......@@ -1473,12 +1407,13 @@ UnitGraph::CSRPtr UnitGraph::GetOutCSR(bool inplace) const {
UnitGraph::COOPtr UnitGraph::GetCOO(bool inplace) const {
if (inplace)
if (!(formats_ & COO_CODE))
LOG(FATAL) << "The graph have restricted sparse format " <<
CodeToStr(formats_) << ", cannot create COO matrix.";
LOG(FATAL) << "The graph have restricted sparse format "
<< CodeToStr(formats_) << ", cannot create COO matrix.";
COOPtr ret = coo_;
if (!coo_->defined()) {
if (in_csr_->defined()) {
const auto& newadj = aten::COOTranspose(aten::CSRToCOO(in_csr_->adj(), true));
const auto& newadj =
aten::COOTranspose(aten::CSRToCOO(in_csr_->adj(), true));
if (inplace)
*(const_cast<UnitGraph*>(this)->coo_) = COO(meta_graph(), newadj);
......@@ -1494,10 +1429,8 @@ UnitGraph::COOPtr UnitGraph::GetCOO(bool inplace) const {
ret = std::make_shared<COO>(meta_graph(), newadj);
}
if (inplace) {
if (IsPinned())
coo_->PinMemory_();
for (auto stream : recorded_streams)
coo_->RecordStream(stream);
if (IsPinned()) coo_->PinMemory_();
for (auto stream : recorded_streams) coo_->RecordStream(stream);
}
}
return ret;
......@@ -1527,18 +1460,13 @@ HeteroGraphPtr UnitGraph::GetAny() const {
dgl_format_code_t UnitGraph::GetCreatedFormats() const {
dgl_format_code_t ret = 0;
if (in_csr_->defined())
ret |= CSC_CODE;
if (out_csr_->defined())
ret |= CSR_CODE;
if (coo_->defined())
ret |= COO_CODE;
if (in_csr_->defined()) ret |= CSC_CODE;
if (out_csr_->defined()) ret |= CSR_CODE;
if (coo_->defined()) ret |= COO_CODE;
return ret;
}
dgl_format_code_t UnitGraph::GetAllowedFormats() const {
return formats_;
}
dgl_format_code_t UnitGraph::GetAllowedFormats() const { return formats_; }
HeteroGraphPtr UnitGraph::GetFormat(SparseFormat format) const {
switch (format) {
......@@ -1555,17 +1483,11 @@ HeteroGraphPtr UnitGraph::GetGraphInFormat(dgl_format_code_t formats) const {
if (formats == ALL_CODE)
return HeteroGraphPtr(
// TODO(xiangsx) Make it as graph storage.Clone()
new UnitGraph(meta_graph_,
(in_csr_->defined())
? CSRPtr(new CSR(*in_csr_))
: nullptr,
(out_csr_->defined())
? CSRPtr(new CSR(*out_csr_))
: nullptr,
(coo_->defined())
? COOPtr(new COO(*coo_))
: nullptr,
formats));
new UnitGraph(
meta_graph_,
(in_csr_->defined()) ? CSRPtr(new CSR(*in_csr_)) : nullptr,
(out_csr_->defined()) ? CSRPtr(new CSR(*out_csr_)) : nullptr,
(coo_->defined()) ? COOPtr(new COO(*coo_)) : nullptr, formats));
int64_t num_vtypes = NumVertexTypes();
if (formats & COO_CODE)
return CreateFromCOO(num_vtypes, GetCOO(false)->adj(), formats);
......@@ -1574,18 +1496,17 @@ HeteroGraphPtr UnitGraph::GetGraphInFormat(dgl_format_code_t formats) const {
return CreateFromCSC(num_vtypes, GetInCSR(false)->adj(), formats);
}
SparseFormat UnitGraph::SelectFormat(dgl_format_code_t preferred_formats) const {
SparseFormat UnitGraph::SelectFormat(
dgl_format_code_t preferred_formats) const {
dgl_format_code_t common = preferred_formats & formats_;
dgl_format_code_t created = GetCreatedFormats();
if (common & created)
return DecodeFormat(common & created);
if (common & created) return DecodeFormat(common & created);
// NOTE(zihao): hypersparse is currently disabled since many CUDA operators on COO have
// not been implmented yet.
// if (coo_->defined() && coo_->IsHypersparse()) // only allow coo for hypersparse graph.
// NOTE(zihao): hypersparse is currently disabled since many CUDA operators on
// COO have not been implemented yet. Only allow coo for hypersparse graph:
// if (coo_->defined() && coo_->IsHypersparse())
// return SparseFormat::kCOO;
if (common)
return DecodeFormat(common);
if (common) return DecodeFormat(common);
return DecodeFormat(created);
}
......@@ -1623,13 +1544,15 @@ HeteroGraphPtr UnitGraph::LineGraph(bool backtracking) const {
}
case SparseFormat::kCSR: {
const aten::CSRMatrix csr = GetCSRMatrix(0);
const aten::COOMatrix coo = aten::COOLineGraph(aten::CSRToCOO(csr, true), backtracking);
const aten::COOMatrix coo =
aten::COOLineGraph(aten::CSRToCOO(csr, true), backtracking);
return CreateFromCOO(1, coo);
}
case SparseFormat::kCSC: {
const aten::CSRMatrix csc = GetCSCMatrix(0);
const aten::CSRMatrix csr = aten::CSRTranspose(csc);
const aten::COOMatrix coo = aten::COOLineGraph(aten::CSRToCOO(csr, true), backtracking);
const aten::COOMatrix coo =
aten::COOLineGraph(aten::CSRToCOO(csr, true), backtracking);
return CreateFromCOO(1, coo);
}
default:
......@@ -1675,8 +1598,8 @@ bool UnitGraph::Load(dmlc::Stream* fs) {
formats_ = CSC_CODE;
break;
default:
LOG(FATAL) << "Load graph failed, formats code " << formats_code <<
"not recognized.";
LOG(FATAL) << "Load graph failed, formats code " << formats_code
<< "not recognized.";
}
}
......@@ -1708,13 +1631,12 @@ bool UnitGraph::Load(dmlc::Stream* fs) {
return true;
}
void UnitGraph::Save(dmlc::Stream* fs) const {
fs->Write(kDGLSerialize_UnitGraphMagic);
// Didn't write UnitGraph::meta_graph_, since it's included in the underlying
// sparse matrix
auto save_formats = SparseFormatsToCode({SelectFormat(ALL_CODE)});
auto fstream = dynamic_cast<dgl::serialize::DGLStream *>(fs);
auto fstream = dynamic_cast<dgl::serialize::DGLStream*>(fs);
if (fstream) {
auto formats = fstream->FormatsToSave();
save_formats = formats == ANY_CODE
......@@ -1738,14 +1660,15 @@ UnitGraphPtr UnitGraph::Reverse() const {
CSRPtr new_incsr = out_csr_, new_outcsr = in_csr_;
COOPtr new_coo = nullptr;
if (coo_->defined()) {
new_coo = COOPtr(new COO(coo_->meta_graph(), aten::COOTranspose(coo_->adj())));
new_coo =
COOPtr(new COO(coo_->meta_graph(), aten::COOTranspose(coo_->adj())));
}
return UnitGraphPtr(new UnitGraph(meta_graph(), new_incsr, new_outcsr, new_coo));
return UnitGraphPtr(
new UnitGraph(meta_graph(), new_incsr, new_outcsr, new_coo));
}
std::tuple<UnitGraphPtr, IdArray, IdArray>
UnitGraph::ToSimple() const {
std::tuple<UnitGraphPtr, IdArray, IdArray> UnitGraph::ToSimple() const {
CSRPtr new_incsr = nullptr, new_outcsr = nullptr;
COOPtr new_coo = nullptr;
IdArray count;
......@@ -1779,9 +1702,9 @@ UnitGraph::ToSimple() const {
break;
}
return std::make_tuple(UnitGraphPtr(new UnitGraph(meta_graph(), new_incsr, new_outcsr, new_coo)),
count,
edge_map);
return std::make_tuple(
UnitGraphPtr(new UnitGraph(meta_graph(), new_incsr, new_outcsr, new_coo)),
count, edge_map);
}
} // namespace dgl
......@@ -7,16 +7,17 @@
#ifndef DGL_GRAPH_UNIT_GRAPH_H_
#define DGL_GRAPH_UNIT_GRAPH_H_
#include <dgl/array.h>
#include <dgl/base_heterograph.h>
#include <dgl/lazy.h>
#include <dgl/array.h>
#include <dmlc/io.h>
#include <dmlc/type_traits.h>
#include <utility>
#include <string>
#include <vector>
#include <memory>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
#include "../c_api_common.h"
......@@ -45,17 +46,11 @@ class UnitGraph : public BaseHeteroGraph {
typedef std::shared_ptr<COO> COOPtr;
typedef std::shared_ptr<CSR> CSRPtr;
inline dgl_type_t SrcType() const {
return 0;
}
inline dgl_type_t SrcType() const { return 0; }
inline dgl_type_t DstType() const {
return NumVertexTypes() == 1? 0 : 1;
}
inline dgl_type_t DstType() const { return NumVertexTypes() == 1 ? 0 : 1; }
inline dgl_type_t EdgeType() const {
return 0;
}
inline dgl_type_t EdgeType() const { return 0; }
HeteroGraphPtr GetRelationGraph(dgl_type_t etype) const override {
LOG(FATAL) << "The method shouldn't be called for UnitGraph graph. "
......@@ -75,9 +70,7 @@ class UnitGraph : public BaseHeteroGraph {
LOG(FATAL) << "UnitGraph graph is not mutable.";
}
void Clear() override {
LOG(FATAL) << "UnitGraph graph is not mutable.";
}
void Clear() override { LOG(FATAL) << "UnitGraph graph is not mutable."; }
DGLDataType DataType() const override;
......@@ -89,9 +82,7 @@ class UnitGraph : public BaseHeteroGraph {
bool IsMultigraph() const override;
bool IsReadonly() const override {
return true;
}
bool IsReadonly() const override { return true; }
uint64_t NumVertices(dgl_type_t vtype) const override;
......@@ -108,9 +99,11 @@ class UnitGraph : public BaseHeteroGraph {
BoolArray HasVertices(dgl_type_t vtype, IdArray vids) const override;
bool HasEdgeBetween(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const override;
bool HasEdgeBetween(
dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const override;
BoolArray HasEdgesBetween(dgl_type_t etype, IdArray src_ids, IdArray dst_ids) const override;
BoolArray HasEdgesBetween(
dgl_type_t etype, IdArray src_ids, IdArray dst_ids) const override;
IdArray Predecessors(dgl_type_t etype, dgl_id_t dst) const override;
......@@ -118,11 +111,13 @@ class UnitGraph : public BaseHeteroGraph {
IdArray EdgeId(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const override;
EdgeArray EdgeIdsAll(dgl_type_t etype, IdArray src, IdArray dst) const override;
EdgeArray EdgeIdsAll(
dgl_type_t etype, IdArray src, IdArray dst) const override;
IdArray EdgeIdsOne(dgl_type_t etype, IdArray src, IdArray dst) const override;
std::pair<dgl_id_t, dgl_id_t> FindEdge(dgl_type_t etype, dgl_id_t eid) const override;
std::pair<dgl_id_t, dgl_id_t> FindEdge(
dgl_type_t etype, dgl_id_t eid) const override;
EdgeArray FindEdges(dgl_type_t etype, IdArray eids) const override;
......@@ -134,7 +129,8 @@ class UnitGraph : public BaseHeteroGraph {
EdgeArray OutEdges(dgl_type_t etype, IdArray vids) const override;
EdgeArray Edges(dgl_type_t etype, const std::string &order = "") const override;
EdgeArray Edges(
dgl_type_t etype, const std::string& order = "") const override;
uint64_t InDegree(dgl_type_t etype, dgl_id_t vid) const override;
......@@ -156,18 +152,20 @@ class UnitGraph : public BaseHeteroGraph {
DGLIdIters InEdgeVec(dgl_type_t etype, dgl_id_t vid) const override;
std::vector<IdArray> GetAdj(
dgl_type_t etype, bool transpose, const std::string &fmt) const override;
dgl_type_t etype, bool transpose, const std::string& fmt) const override;
HeteroSubgraph VertexSubgraph(const std::vector<IdArray>& vids) const override;
HeteroSubgraph VertexSubgraph(
const std::vector<IdArray>& vids) const override;
HeteroSubgraph EdgeSubgraph(
const std::vector<IdArray>& eids, bool preserve_nodes = false) const override;
const std::vector<IdArray>& eids,
bool preserve_nodes = false) const override;
// creators
/** @brief Create a graph with no edges */
static HeteroGraphPtr Empty(
int64_t num_vtypes, int64_t num_src, int64_t num_dst,
DGLDataType dtype, DGLContext ctx) {
int64_t num_vtypes, int64_t num_src, int64_t num_dst, DGLDataType dtype,
DGLContext ctx) {
IdArray row = IdArray::Empty({0}, dtype, ctx);
IdArray col = IdArray::Empty({0}, dtype, ctx);
return CreateFromCOO(num_vtypes, num_src, num_dst, row, col);
......@@ -175,9 +173,9 @@ class UnitGraph : public BaseHeteroGraph {
/** @brief Create a graph from COO arrays */
static HeteroGraphPtr CreateFromCOO(
int64_t num_vtypes, int64_t num_src, int64_t num_dst,
IdArray row, IdArray col, bool row_sorted = false,
bool col_sorted = false, dgl_format_code_t formats = ALL_CODE);
int64_t num_vtypes, int64_t num_src, int64_t num_dst, IdArray row,
IdArray col, bool row_sorted = false, bool col_sorted = false,
dgl_format_code_t formats = ALL_CODE);
static HeteroGraphPtr CreateFromCOO(
int64_t num_vtypes, const aten::COOMatrix& mat,
......@@ -185,9 +183,8 @@ class UnitGraph : public BaseHeteroGraph {
/** @brief Create a graph from (out) CSR arrays */
static HeteroGraphPtr CreateFromCSR(
int64_t num_vtypes, int64_t num_src, int64_t num_dst,
IdArray indptr, IdArray indices, IdArray edge_ids,
dgl_format_code_t formats = ALL_CODE);
int64_t num_vtypes, int64_t num_src, int64_t num_dst, IdArray indptr,
IdArray indices, IdArray edge_ids, dgl_format_code_t formats = ALL_CODE);
static HeteroGraphPtr CreateFromCSR(
int64_t num_vtypes, const aten::CSRMatrix& mat,
......@@ -195,9 +192,8 @@ class UnitGraph : public BaseHeteroGraph {
/** @brief Create a graph from (in) CSC arrays */
static HeteroGraphPtr CreateFromCSC(
int64_t num_vtypes, int64_t num_src, int64_t num_dst,
IdArray indptr, IdArray indices, IdArray edge_ids,
dgl_format_code_t formats = ALL_CODE);
int64_t num_vtypes, int64_t num_src, int64_t num_dst, IdArray indptr,
IdArray indices, IdArray edge_ids, dgl_format_code_t formats = ALL_CODE);
static HeteroGraphPtr CreateFromCSC(
int64_t num_vtypes, const aten::CSRMatrix& mat,
......@@ -207,24 +203,22 @@ class UnitGraph : public BaseHeteroGraph {
static HeteroGraphPtr AsNumBits(HeteroGraphPtr g, uint8_t bits);
/** @brief Copy the data to another context */
static HeteroGraphPtr CopyTo(HeteroGraphPtr g, const DGLContext &ctx);
static HeteroGraphPtr CopyTo(HeteroGraphPtr g, const DGLContext& ctx);
/**
* @brief Pin the in_csr_, out_csr_ and coo_ of the current graph.
* @note The graph will be pinned inplace. Behavior depends on the current context,
* kDGLCPU: will be pinned;
* IsPinned: directly return;
* kDGLCUDA: invalid, will throw an error.
* The context check is deferred to pinning the NDArray.
* @note The graph will be pinned inplace. Behavior depends on the current
* context:
*   - kDGLCPU: will be pinned;
*   - IsPinned: directly return;
*   - kDGLCUDA: invalid, will throw an error.
* The context check is deferred to pinning the NDArray.
*/
void PinMemory_() override;
/**
* @brief Unpin the in_csr_, out_csr_ and coo_ of the current graph.
* @note The graph will be unpinned inplace. Behavior depends on the current context,
* IsPinned: will be unpinned;
* others: directly return.
* The context check is deferred to unpinning the NDArray.
* @note The graph will be unpinned inplace. Behavior depends on the current
* context:
*   - IsPinned: will be unpinned;
*   - others: directly return.
* The context check is deferred to unpinning the NDArray.
*/
void UnpinMemory_();
......@@ -236,24 +230,29 @@ class UnitGraph : public BaseHeteroGraph {
/**
* @brief Create in-edge CSR format of the unit graph.
* @param inplace if true and the in-edge CSR format does not exist, the created
* format will be cached in this object unless the format is restricted.
* @return Return the in-edge CSR format. Create from other format if not exist.
* @param inplace if true and the in-edge CSR format does not exist, the
* created format will be cached in this object unless the format is
* restricted.
* @return Return the in-edge CSR format. Create from other format if not
* exist.
*/
CSRPtr GetInCSR(bool inplace = true) const;
/**
* @brief Create out-edge CSR format of the unit graph.
* @param inplace if true and the out-edge CSR format does not exist, the created
* format will be cached in this object unless the format is restricted.
* @return Return the out-edge CSR format. Create from other format if not exist.
* @param inplace if true and the out-edge CSR format does not exist, the
* created format will be cached in this object unless the format is
* restricted.
* @return Return the out-edge CSR format. Create from other format if not
* exist.
*/
CSRPtr GetOutCSR(bool inplace = true) const;
/**
* @brief Create COO format of the unit graph.
* @param inplace if true and the COO format does not exist, the created
* format will be cached in this object unless the format is restricted.
* format will be cached in this object unless the format is
* restricted.
* @return Return the COO format. Create from other format if not exist.
*/
COOPtr GetCOO(bool inplace = true) const;
......@@ -267,13 +266,14 @@ class UnitGraph : public BaseHeteroGraph {
/** @return Return the out-edge CSR in the matrix form */
aten::CSRMatrix GetCSRMatrix(dgl_type_t etype) const override;
SparseFormat SelectFormat(dgl_type_t etype, dgl_format_code_t preferred_formats) const override {
SparseFormat SelectFormat(
dgl_type_t etype, dgl_format_code_t preferred_formats) const override {
return SelectFormat(preferred_formats);
}
/**
* @brief Return the graph in the given format. Perform format conversion if the
* requested format does not exist.
* @brief Return the graph in the given format. Perform format conversion if
* the requested format does not exist.
*
* @return A graph in the requested format.
*/
......@@ -298,11 +298,11 @@ class UnitGraph : public BaseHeteroGraph {
UnitGraphPtr Reverse() const;
/** @return the simpled (no-multi-edge) graph
* the count recording the number of duplicated edges from the original graph.
* the edge mapping from the edge IDs of original graph to those of the
* returned graph.
* the count recording the number of duplicated edges from the original
* graph, and the edge mapping from the edge IDs of original graph to
* those of the returned graph.
*/
std::tuple<UnitGraphPtr, IdArray, IdArray>ToSimple() const;
std::tuple<UnitGraphPtr, IdArray, IdArray> ToSimple() const;
void InvalidateCSR();
......@@ -326,7 +326,8 @@ class UnitGraph : public BaseHeteroGraph {
* @param out_csr out edge csr
* @param coo coo
*/
UnitGraph(GraphPtr metagraph, CSRPtr in_csr, CSRPtr out_csr, COOPtr coo,
UnitGraph(
GraphPtr metagraph, CSRPtr in_csr, CSRPtr out_csr, COOPtr coo,
dgl_format_code_t formats = ALL_CODE);
/**
......@@ -341,13 +342,9 @@ class UnitGraph : public BaseHeteroGraph {
* @param has_coo whether coo is valid
*/
static HeteroGraphPtr CreateUnitGraphFrom(
int num_vtypes,
const aten::CSRMatrix &in_csr,
const aten::CSRMatrix &out_csr,
const aten::COOMatrix &coo,
bool has_in_csr,
bool has_out_csr,
bool has_coo,
int num_vtypes, const aten::CSRMatrix& in_csr,
const aten::CSRMatrix& out_csr, const aten::COOMatrix& coo,
bool has_in_csr, bool has_out_csr, bool has_coo,
dgl_format_code_t formats = ALL_CODE);
/** @return Return any existing format. */
......@@ -356,8 +353,8 @@ class UnitGraph : public BaseHeteroGraph {
/**
* @brief Determine which format to use with a preference.
*
* If the storage of unit graph is "locked", i.e. no conversion is allowed, then
* it will return the locked format.
* If the storage of unit graph is "locked", i.e. no conversion is allowed,
* then it will return the locked format.
*
* Otherwise, it will return whatever DGL thinks is the most appropriate given
* the arguments.
......@@ -369,8 +366,8 @@ class UnitGraph : public BaseHeteroGraph {
GraphPtr AsImmutableGraph() const override;
// Graph stored in different format. We use an on-demand strategy: the format is
// only materialized if the operation that suitable for it is invoked.
// Graph stored in different formats. We use an on-demand strategy: a format
// is only materialized if an operation that is suitable for it is invoked.
/** @brief CSR graph that stores reverse edges */
CSRPtr in_csr_;
/** @brief CSR representation */
......
......@@ -6,8 +6,10 @@
#include <dgl/array.h>
#include <dgl/random.h>
#include <numeric>
#include <vector>
#include "sample_utils.h"
namespace dgl {
......@@ -27,8 +29,8 @@ template int32_t RandomEngine::Choice<int32_t>(FloatArray);
template int64_t RandomEngine::Choice<int64_t>(FloatArray);
template <typename IdxType, typename FloatType>
void RandomEngine::Choice(IdxType num, FloatArray prob, IdxType* out,
bool replace) {
void RandomEngine::Choice(
IdxType num, FloatArray prob, IdxType* out, bool replace) {
const IdxType N = prob->shape[0];
if (!replace)
CHECK_LE(num, N)
......@@ -45,30 +47,26 @@ void RandomEngine::Choice(IdxType num, FloatArray prob, IdxType* out,
delete sampler;
}
template void RandomEngine::Choice<int32_t, float>(int32_t num, FloatArray prob,
int32_t* out, bool replace);
template void RandomEngine::Choice<int64_t, float>(int64_t num, FloatArray prob,
int64_t* out, bool replace);
template void RandomEngine::Choice<int32_t, double>(int32_t num,
FloatArray prob,
int32_t* out, bool replace);
template void RandomEngine::Choice<int64_t, double>(int64_t num,
FloatArray prob,
int64_t* out, bool replace);
template void RandomEngine::Choice<int32_t, int8_t>(int32_t num, FloatArray prob,
int32_t* out, bool replace);
template void RandomEngine::Choice<int64_t, int8_t>(int64_t num, FloatArray prob,
int64_t* out, bool replace);
template void RandomEngine::Choice<int32_t, uint8_t>(int32_t num,
FloatArray prob,
int32_t* out, bool replace);
template void RandomEngine::Choice<int64_t, uint8_t>(int64_t num,
FloatArray prob,
int64_t* out, bool replace);
template void RandomEngine::Choice<int32_t, float>(
int32_t num, FloatArray prob, int32_t* out, bool replace);
template void RandomEngine::Choice<int64_t, float>(
int64_t num, FloatArray prob, int64_t* out, bool replace);
template void RandomEngine::Choice<int32_t, double>(
int32_t num, FloatArray prob, int32_t* out, bool replace);
template void RandomEngine::Choice<int64_t, double>(
int64_t num, FloatArray prob, int64_t* out, bool replace);
template void RandomEngine::Choice<int32_t, int8_t>(
int32_t num, FloatArray prob, int32_t* out, bool replace);
template void RandomEngine::Choice<int64_t, int8_t>(
int64_t num, FloatArray prob, int64_t* out, bool replace);
template void RandomEngine::Choice<int32_t, uint8_t>(
int32_t num, FloatArray prob, int32_t* out, bool replace);
template void RandomEngine::Choice<int64_t, uint8_t>(
int64_t num, FloatArray prob, int64_t* out, bool replace);
template <typename IdxType>
void RandomEngine::UniformChoice(IdxType num, IdxType population, IdxType* out,
bool replace) {
void RandomEngine::UniformChoice(
IdxType num, IdxType population, IdxType* out, bool replace) {
CHECK_GE(num, 0) << "The numbers to sample should be non-negative.";
CHECK_GE(population, 0) << "The population size should be non-negative.";
if (!replace)
......@@ -112,14 +110,15 @@ void RandomEngine::UniformChoice(IdxType num, IdxType population, IdxType* out,
}
} else {
// In this case, `num >= population / 10`. To reduce the computation overhead,
// we should reduce the number of random number generations. Even though
// reservoir algorithm is more memory efficient (it has O(num) memory complexity),
// it generates O(population) random numbers, which is computationally expensive.
// This algorithm has memory complexity of O(population) but generates much fewer random
// numbers O(num). In the case of `num >= population/10`, we don't need to worry about
// memory complexity because `num` is usually small. So is `population`. Allocating a small
// piece of memory is very efficient.
// In this case, `num >= population / 10`. To reduce the computation
// overhead, we should reduce the number of random number generations.
// Even though reservoir algorithm is more memory efficient (it has
// O(num) memory complexity), it generates O(population) random numbers,
// which is computationally expensive. This algorithm has memory
// complexity of O(population) but generates much fewer random numbers
// O(num). In the case of `num >= population/10`, we don't need to worry
// about memory complexity because `num` is usually small. So is
// `population`. Allocating a small piece of memory is very efficient.
std::vector<IdxType> seq(population);
for (size_t i = 0; i < seq.size(); i++) seq[i] = i;
for (IdxType i = 0; i < num; i++) {
......@@ -134,23 +133,22 @@ void RandomEngine::UniformChoice(IdxType num, IdxType population, IdxType* out,
}
}
template void RandomEngine::UniformChoice<int32_t>(int32_t num,
int32_t population,
int32_t* out, bool replace);
template void RandomEngine::UniformChoice<int64_t>(int64_t num,
int64_t population,
int64_t* out, bool replace);
template void RandomEngine::UniformChoice<int32_t>(
int32_t num, int32_t population, int32_t* out, bool replace);
template void RandomEngine::UniformChoice<int64_t>(
int64_t num, int64_t population, int64_t* out, bool replace);
template <typename IdxType, typename FloatType>
void RandomEngine::BiasedChoice(
IdxType num, const IdxType *split, FloatArray bias, IdxType* out, bool replace) {
IdxType num, const IdxType* split, FloatArray bias, IdxType* out,
bool replace) {
const int64_t num_tags = bias->shape[0];
const FloatType *bias_data = static_cast<FloatType *>(bias->data);
const FloatType* bias_data = static_cast<FloatType*>(bias->data);
IdxType total_node_num = 0;
FloatArray prob = NDArray::Empty({num_tags}, bias->dtype, bias->ctx);
FloatType *prob_data = static_cast<FloatType *>(prob->data);
for (int64_t tag = 0 ; tag < num_tags; ++tag) {
int64_t tag_num_nodes = split[tag+1] - split[tag];
FloatType* prob_data = static_cast<FloatType*>(prob->data);
for (int64_t tag = 0; tag < num_tags; ++tag) {
int64_t tag_num_nodes = split[tag + 1] - split[tag];
total_node_num += tag_num_nodes;
FloatType tag_bias = bias_data[tag];
prob_data[tag] = tag_num_nodes * tag_bias;
......@@ -159,19 +157,21 @@ void RandomEngine::BiasedChoice(
auto sampler = utils::TreeSampler<IdxType, FloatType, true>(this, prob);
for (IdxType i = 0; i < num; ++i) {
const int64_t tag = sampler.Draw();
const IdxType tag_num_nodes = split[tag+1] - split[tag];
const IdxType tag_num_nodes = split[tag + 1] - split[tag];
out[i] = RandInt(tag_num_nodes) + split[tag];
}
} else {
utils::TreeSampler<int64_t, FloatType, false> sampler(this, prob, bias_data);
utils::TreeSampler<int64_t, FloatType, false> sampler(
this, prob, bias_data);
CHECK_GE(total_node_num, num)
<< "Cannot take more sample than population when 'replace=false'";
// we use hash set here. Maybe in the future we should support reservoir algorithm
// we use hash set here. Maybe in the future we should support reservoir
// algorithm
std::vector<std::unordered_set<IdxType>> selected(num_tags);
for (IdxType i = 0 ; i < num ; ++i) {
for (IdxType i = 0; i < num; ++i) {
const int64_t tag = sampler.Draw();
bool inserted = false;
const IdxType tag_num_nodes = split[tag+1] - split[tag];
const IdxType tag_num_nodes = split[tag + 1] - split[tag];
IdxType selected_node;
while (!inserted) {
CHECK_LT(selected[tag].size(), tag_num_nodes)
......
......@@ -6,15 +6,16 @@
#ifndef DGL_RANDOM_CPU_SAMPLE_UTILS_H_
#define DGL_RANDOM_CPU_SAMPLE_UTILS_H_
#include <dgl/random.h>
#include <dgl/array.h>
#include <dgl/random.h>
#include <algorithm>
#include <utility>
#include <queue>
#include <cstdlib>
#include <cmath>
#include <numeric>
#include <cstdlib>
#include <limits>
#include <numeric>
#include <queue>
#include <utility>
#include <vector>
namespace dgl {
......@@ -32,23 +33,21 @@ class BaseSampler {
}
};
// (BarclayII 2022.9.20) Changing the internal data type of probabilities to double since
// we are using non-uniform sampling to sample on boolean masks, where False represents
// probability 0. DType could be uint8 in this case, which will give incorrect arithmetic
// results due to overflowing and/or integer division.
// (BarclayII 2022.9.20) Changing the internal data type of probabilities to
// double since we are using non-uniform sampling to sample on boolean masks,
// where False represents probability 0. DType could be uint8 in this case,
// which will give incorrect arithmetic results due to overflowing and/or
// integer division.
/**
* AliasSampler is used to sample elements from a given discrete categorical distribution.
* Algorithm: Alias Method(https://en.wikipedia.org/wiki/Alias_method)
* Sampler building complexity: O(n)
* Sample w/ replacement complexity: O(1)
* Sample w/o replacement complexity: O(log n)
* AliasSampler is used to sample elements from a given discrete categorical
* distribution.
*
* Algorithm: Alias Method (https://en.wikipedia.org/wiki/Alias_method)
* Sampler building complexity: O(n)
* Sample w/ replacement complexity: O(1)
* Sample w/o replacement complexity: O(log n)
*/
template <
typename Idx,
typename DType,
bool replace>
class AliasSampler: public BaseSampler<Idx> {
template <typename Idx, typename DType, bool replace>
class AliasSampler : public BaseSampler<Idx> {
private:
RandomEngine *re;
Idx N;
......@@ -56,7 +55,8 @@ class AliasSampler: public BaseSampler<Idx> {
std::vector<Idx> K; // alias table
std::vector<double> U; // probability table
FloatArray _prob; // category distribution
std::vector<bool> used; // indicate availability, activated when replace=false;
std::vector<bool>
used; // indicate availability, activated when replace=false;
std::vector<Idx> id_mapping; // index mapping, activated when replace=false;
inline Idx Map(Idx x) const { // Map consecutive indices to unused elements
......@@ -72,16 +72,16 @@ class AliasSampler: public BaseSampler<Idx> {
N = 0;
accum = 0.;
taken = 0.;
if (!replace)
id_mapping.clear();
if (!replace) id_mapping.clear();
for (Idx i = 0; i < prob_size; ++i)
if (!used[i]) {
N++;
accum += prob_data[i];
if (!replace)
id_mapping.push_back(i);
if (!replace) id_mapping.push_back(i);
}
if (N == 0) LOG(FATAL) << "Cannot take more sample than population when 'replace=false'";
if (N == 0)
LOG(FATAL)
<< "Cannot take more sample than population when 'replace=false'";
K.resize(N);
U.resize(N);
double avg = accum / static_cast<double>(N);
......@@ -113,13 +113,12 @@ class AliasSampler: public BaseSampler<Idx> {
public:
/**
* @brief Re-initialize the sampler for the distribution `prob`.
*
* Resizes `used` to prob->shape[0] entries and clears every flag, keeps a
* handle to `prob` in `_prob` when sampling without replacement, and then
* rebuilds the sampler's tables through Reconstruct(prob).
*
* @param prob The category distribution to sample from.
*/
void ResetState(FloatArray prob) {
used.resize(prob->shape[0]);
if (!replace)
_prob = prob;
if (!replace) _prob = prob;
std::fill(used.begin(), used.end(), false);
Reconstruct(prob);
}
/**
* @brief Construct a sampler over `prob`.
*
* Stores the random engine pointer and delegates all table setup to
* ResetState(prob).
*
* @param re Random engine stored in the `re` member.
* @param prob The category distribution to sample from.
*/
explicit AliasSampler(RandomEngine* re, FloatArray prob): re(re) {
explicit AliasSampler(RandomEngine *re, FloatArray prob) : re(re) {
ResetState(prob);
}
......@@ -128,11 +127,10 @@ class AliasSampler: public BaseSampler<Idx> {
Idx Draw() {
if (!replace) {
const DType *_prob_data = _prob.Ptr<DType>();
if (2 * taken >= accum)
Reconstruct(_prob);
if (accum <= 0)
return -1;
// accum changes after Reconstruct(), so avg should be computed after that.
if (2 * taken >= accum) Reconstruct(_prob);
if (accum <= 0) return -1;
// accum changes after Reconstruct(), so avg should be computed after
// that.
double avg = accum / N;
while (true) {
double dice = re->Uniform<double>(0, N);
......@@ -151,8 +149,7 @@ class AliasSampler: public BaseSampler<Idx> {
}
}
}
if (accum <= 0)
return -1;
if (accum <= 0) return -1;
double avg = accum / N;
double dice = re->Uniform<double>(0, N);
Idx i = static_cast<Idx>(dice);
......@@ -164,27 +161,26 @@ class AliasSampler: public BaseSampler<Idx> {
}
};
/**
* CDFSampler is used to sample elements from a given discrete categorical distribution.
* Algorithm: create a cumulative distribution function and conduct binary search for sampling.
* Reference: https://github.com/numpy/numpy/blob/d37908/numpy/random/mtrand.pyx#L804
* CDFSampler is used to sample elements from a given discrete categorical
* distribution.
*
* Algorithm: create a cumulative distribution function and conduct binary
* search for sampling.
* Reference:
* https://github.com/numpy/numpy/blob/d37908/numpy/random/mtrand.pyx#L804
* Sampler building complexity: O(n)
* Sample w/ and w/o replacement complexity: O(log n)
*/
template <
typename Idx,
typename DType,
bool replace>
class CDFSampler: public BaseSampler<Idx> {
template <typename Idx, typename DType, bool replace>
class CDFSampler : public BaseSampler<Idx> {
private:
RandomEngine *re;
Idx N;
double accum, taken;
FloatArray _prob; // categorical distribution
std::vector<double> cdf; // cumulative distribution function
std::vector<bool> used; // indicate availability, activated when replace=false;
std::vector<Idx> id_mapping; // indicate index mapping, activated when replace=false;
std::vector<bool>
used; // indicate availability, activated when replace=false;
std::vector<Idx>
id_mapping; // indicate index mapping, activated when replace=false;
inline Idx Map(Idx x) const { // Map consecutive indices to unused elements
if (replace)
......@@ -199,31 +195,30 @@ class CDFSampler: public BaseSampler<Idx> {
N = 0;
accum = 0.;
taken = 0.;
if (!replace)
id_mapping.clear();
if (!replace) id_mapping.clear();
cdf.clear();
cdf.push_back(0);
for (Idx i = 0; i < prob_size; ++i)
if (!used[i]) {
N++;
accum += prob_data[i];
if (!replace)
id_mapping.push_back(i);
if (!replace) id_mapping.push_back(i);
cdf.push_back(accum);
}
if (N == 0) LOG(FATAL) << "Cannot take more sample than population when 'replace=false'";
if (N == 0)
LOG(FATAL)
<< "Cannot take more sample than population when 'replace=false'";
}
public:
/**
* @brief Re-initialize the sampler for the distribution `prob`.
*
* Resizes `used` to prob->shape[0] entries and clears every flag, keeps a
* handle to `prob` in `_prob` when sampling without replacement, and then
* rebuilds the cumulative table through Reconstruct(prob).
*
* @param prob The categorical distribution to sample from.
*/
void ResetState(FloatArray prob) {
used.resize(prob->shape[0]);
if (!replace)
_prob = prob;
if (!replace) _prob = prob;
std::fill(used.begin(), used.end(), false);
Reconstruct(prob);
}
/**
* @brief Construct a sampler over `prob`.
*
* Stores the random engine pointer and delegates all setup to
* ResetState(prob).
*
* @param re Random engine stored in the `re` member.
* @param prob The categorical distribution to sample from.
*/
explicit CDFSampler(RandomEngine *re, FloatArray prob): re(re) {
explicit CDFSampler(RandomEngine *re, FloatArray prob) : re(re) {
ResetState(prob);
}
......@@ -233,13 +228,12 @@ class CDFSampler: public BaseSampler<Idx> {
double eps = std::numeric_limits<double>::min();
if (!replace) {
const DType *_prob_data = _prob.Ptr<DType>();
if (2 * taken >= accum)
Reconstruct(_prob);
if (accum <= 0)
return -1;
if (2 * taken >= accum) Reconstruct(_prob);
if (accum <= 0) return -1;
while (true) {
double p = std::max(re->Uniform<double>(0., accum), eps);
Idx rst = Map(std::lower_bound(cdf.begin(), cdf.end(), p) - cdf.begin() - 1);
Idx rst =
Map(std::lower_bound(cdf.begin(), cdf.end(), p) - cdf.begin() - 1);
double cap = static_cast<double>(_prob_data[rst]);
if (!used[rst]) {
used[rst] = true;
......@@ -248,26 +242,21 @@ class CDFSampler: public BaseSampler<Idx> {
}
}
}
if (accum <= 0)
return -1;
if (accum <= 0) return -1;
double p = std::max(re->Uniform<double>(0., accum), eps);
return Map(std::lower_bound(cdf.begin(), cdf.end(), p) - cdf.begin() - 1);
}
};
/**
* TreeSampler is used to sample elements from a given discrete categorical distribution.
* Algorithm: create a heap that stores accumulated likelihood of its leaf descendents.
* Reference: https://blog.smola.org/post/1016514759
* TreeSampler is used to sample elements from a given discrete categorical
* distribution.
*
* Algorithm: create a heap that stores accumulated likelihood of its leaf
* descendants.
* Reference: https://blog.smola.org/post/1016514759
* Sampler building complexity: O(n)
* Sample w/ and w/o replacement complexity: O(log n)
*/
template <
typename Idx,
typename DType,
bool replace>
class TreeSampler: public BaseSampler<Idx> {
template <typename Idx, typename DType, bool replace>
class TreeSampler : public BaseSampler<Idx> {
private:
RandomEngine *re;
std::vector<double> weight; // accumulated likelihood of subtrees.
......@@ -286,11 +275,11 @@ class TreeSampler: public BaseSampler<Idx> {
weight[i] = weight[i * 2] + weight[i * 2 + 1];
}
explicit TreeSampler(RandomEngine *re, FloatArray prob, const DType* decrease = nullptr)
explicit TreeSampler(
RandomEngine *re, FloatArray prob, const DType *decrease = nullptr)
: re(re), decrease(decrease) {
num_leafs = 1;
while (num_leafs < prob->shape[0])
num_leafs *= 2;
while (num_leafs < prob->shape[0]) num_leafs *= 2;
N = num_leafs * 2;
weight.resize(N);
ResetState(prob);
......@@ -298,18 +287,18 @@ class TreeSampler: public BaseSampler<Idx> {
/* Pick an element from the given distribution and update the tree.
*
* The parameter decrease is an array of which the length is the number of categories.
* Every time an element in the category x is picked, the weight of this category is subtracted
* by decrease[x]. It is used to support the case where a category might contains multiple
* candidates and decrease[x] is the weight of one candidate of the category x.
* The parameter decrease is an array whose length is the number of
* categories. Every time an element in the category x is picked, the weight
* of this category is reduced by decrease[x]. It is used to support the
* case where a category might contain multiple candidates and decrease[x] is
* the weight of one candidate of the category x.
*
* When decrease == nullptr, it means there is only one candidate in each category and will
* directly set the weight of the chosen category as 0.
* When decrease == nullptr, it means there is only one candidate in each
* category and will directly set the weight of the chosen category as 0.
*
*/
Idx Draw() {
if (weight[1] <= 0)
return -1;
if (weight[1] <= 0) return -1;
int64_t cur = 1;
double p = re->Uniform<double>(0, weight[cur]);
double accum = 0.;
......@@ -319,15 +308,16 @@ class TreeSampler: public BaseSampler<Idx> {
// w_r > 0 can suppress some numerical problems.
Idx shift = static_cast<Idx>(p > pivot && w_r > 0);
cur = cur * 2 + shift;
if (shift == 1)
accum = pivot;
if (shift == 1) accum = pivot;
}
Idx rst = cur - num_leafs;
if (!replace) {
while (cur >= 1) {
if (cur >= num_leafs)
weight[cur] = this->decrease ?
weight[cur] - static_cast<double>(this->decrease[rst]) : 0.;
weight[cur] =
this->decrease
? weight[cur] - static_cast<double>(this->decrease[rst])
: 0.;
else
weight[cur] = weight[cur * 2] + weight[cur * 2 + 1];
cur /= 2;
......
......@@ -3,14 +3,15 @@
* @file ndarray.cc
* @brief NDArray container infrastructure.
*/
#include <string.h>
#include <dmlc/logging.h>
#include <dgl/runtime/ndarray.h>
#include <dgl/runtime/c_runtime_api.h>
#include <dgl/runtime/device_api.h>
#include <dgl/runtime/ndarray.h>
#include <dgl/runtime/shared_mem.h>
#include <dgl/zerocopy_serializer.h>
#include <dgl/runtime/tensordispatch.h>
#include <dgl/zerocopy_serializer.h>
#include <dmlc/logging.h>
#include <string.h>
#include "runtime_base.h"
namespace dgl {
......@@ -66,16 +67,15 @@ void NDArray::Internal::DefaultDeleter(NDArray::Container* ptr) {
ptr->mem = nullptr;
} else if (ptr->dl_tensor.data != nullptr) {
// if the array is still pinned before freeing, unpin it.
if (ptr->pinned_by_dgl_)
UnpinContainer(ptr);
dgl::runtime::DeviceAPI::Get(ptr->dl_tensor.ctx)->FreeDataSpace(
ptr->dl_tensor.ctx, ptr->dl_tensor.data);
if (ptr->pinned_by_dgl_) UnpinContainer(ptr);
dgl::runtime::DeviceAPI::Get(ptr->dl_tensor.ctx)
->FreeDataSpace(ptr->dl_tensor.ctx, ptr->dl_tensor.data);
}
delete ptr;
}
NDArray NDArray::Internal::Create(std::vector<int64_t> shape,
DGLDataType dtype, DGLContext ctx) {
NDArray NDArray::Internal::Create(
std::vector<int64_t> shape, DGLDataType dtype, DGLContext ctx) {
VerifyDataType(dtype);
// critical zone
NDArray::Container* data = new NDArray::Container();
......@@ -91,7 +91,7 @@ NDArray NDArray::Internal::Create(std::vector<int64_t> shape,
// does not support NULL stride and thus will crash the program).
data->stride_.resize(data->dl_tensor.ndim, 1);
for (int i = data->dl_tensor.ndim - 2; i >= 0; --i) {
data->stride_[i] = data->shape_[i+1] * data->stride_[i+1];
data->stride_[i] = data->shape_[i + 1] * data->stride_[i + 1];
}
data->dl_tensor.strides = dmlc::BeginPtr(data->stride_);
// setup dtype
......@@ -108,13 +108,10 @@ DGLArray* NDArray::Internal::MoveAsDGLArray(NDArray arr) {
return tensor;
}
/** @brief Return the result of GetDataSize() on the underlying dl_tensor. */
size_t NDArray::GetSize() const {
return GetDataSize(data_->dl_tensor);
}
size_t NDArray::GetSize() const { return GetDataSize(data_->dl_tensor); }
int64_t NDArray::NumElements() const {
if (data_->dl_tensor.ndim == 0)
return 0;
if (data_->dl_tensor.ndim == 0) return 0;
int64_t size = 1;
for (int i = 0; i < data_->dl_tensor.ndim; ++i) {
size *= data_->dl_tensor.shape[i];
......@@ -124,10 +121,10 @@ int64_t NDArray::NumElements() const {
bool NDArray::IsContiguous() const {
CHECK(data_ != nullptr);
if (data_->dl_tensor.strides == nullptr)
return true;
if (data_->dl_tensor.strides == nullptr) return true;
// See https://github.com/dmlc/dgl/issues/2118 and PyTorch's compute_contiguous() implementation
// See https://github.com/dmlc/dgl/issues/2118 and PyTorch's
// compute_contiguous() implementation
int64_t z = 1;
for (int64_t i = data_->dl_tensor.ndim - 1; i >= 0; --i) {
if (data_->dl_tensor.shape[i] != 1) {
......@@ -140,14 +137,12 @@ bool NDArray::IsContiguous() const {
return true;
}
NDArray NDArray::CreateView(std::vector<int64_t> shape,
DGLDataType dtype,
int64_t offset) {
NDArray NDArray::CreateView(
std::vector<int64_t> shape, DGLDataType dtype, int64_t offset) {
CHECK(data_ != nullptr);
CHECK(IsContiguous()) << "Can only create view for compact tensor";
NDArray ret = Internal::Create(shape, dtype, data_->dl_tensor.ctx);
ret.data_->dl_tensor.byte_offset =
this->data_->dl_tensor.byte_offset;
ret.data_->dl_tensor.byte_offset = this->data_->dl_tensor.byte_offset;
size_t curr_size = GetDataSize(this->data_->dl_tensor);
size_t view_size = GetDataSize(ret.data_->dl_tensor);
CHECK_LE(view_size, curr_size)
......@@ -160,9 +155,8 @@ NDArray NDArray::CreateView(std::vector<int64_t> shape,
return ret;
}
NDArray NDArray::EmptyShared(const std::string &name,
std::vector<int64_t> shape,
DGLDataType dtype,
NDArray NDArray::EmptyShared(
const std::string& name, std::vector<int64_t> shape, DGLDataType dtype,
DGLContext ctx, bool is_create) {
NDArray ret = Internal::Create(shape, dtype, ctx);
// setup memory content
......@@ -178,30 +172,27 @@ NDArray NDArray::EmptyShared(const std::string &name,
return ret;
}
/**
* @brief Create an NDArray with the given shape, dtype and context and
* allocate storage for it.
*
* The container is built by Internal::Create. Storage is requested from the
* DeviceAPI of the array's context only when the computed data size is
* non-zero, so the data pointer is not assigned here for zero-size arrays.
*
* @param shape Shape of the array.
* @param dtype Element data type.
* @param ctx Context to allocate on.
* @return The newly created array.
*/
NDArray NDArray::Empty(std::vector<int64_t> shape,
DGLDataType dtype,
DGLContext ctx) {
NDArray NDArray::Empty(
std::vector<int64_t> shape, DGLDataType dtype, DGLContext ctx) {
NDArray ret = Internal::Create(shape, dtype, ctx);
// setup memory content
size_t size = GetDataSize(ret.data_->dl_tensor);
size_t alignment = GetDataAlignment(ret.data_->dl_tensor);
// Skip the allocation entirely for zero-size arrays.
if (size > 0)
ret.data_->dl_tensor.data =
DeviceAPI::Get(ret->ctx)->AllocDataSpace(
ret.data_->dl_tensor.data = DeviceAPI::Get(ret->ctx)->AllocDataSpace(
ret->ctx, size, alignment, ret->dtype);
return ret;
}
void NDArray::CopyFromTo(DGLArray* from,
DGLArray* to) {
void NDArray::CopyFromTo(DGLArray* from, DGLArray* to) {
size_t from_size = GetDataSize(*from);
size_t to_size = GetDataSize(*to);
CHECK_EQ(from_size, to_size)
<< "DGLArrayCopyFromTo: The size must exactly match";
CHECK(from->ctx.device_type == to->ctx.device_type
|| from->ctx.device_type == kDGLCPU
|| to->ctx.device_type == kDGLCPU)
CHECK(
from->ctx.device_type == to->ctx.device_type ||
from->ctx.device_type == kDGLCPU || to->ctx.device_type == kDGLCPU)
<< "Can not copy across different ctx types directly";
// Use the context that is *not* a cpu context to get the correct device
......@@ -210,9 +201,9 @@ void NDArray::CopyFromTo(DGLArray* from,
// default: local current cuda stream
DeviceAPI::Get(ctx)->CopyDataFromTo(
from->data, static_cast<size_t>(from->byte_offset),
to->data, static_cast<size_t>(to->byte_offset),
from_size, from->ctx, to->ctx, from->dtype);
from->data, static_cast<size_t>(from->byte_offset), to->data,
static_cast<size_t>(to->byte_offset), from_size, from->ctx, to->ctx,
from->dtype);
}
void NDArray::PinContainer(NDArray::Container* ptr) {
......@@ -239,65 +230,61 @@ void NDArray::UnpinContainer(NDArray::Container* ptr) {
/**
* @brief Forward a record-stream request for `tensor` to the global
* TensorDispatcher.
*
* Fails a CHECK if the dispatcher is not available or if the tensor's
* device type is not kDGLCUDA.
*
* @param tensor Tensor whose data pointer and device id are passed on.
* @param stream Stream handle passed on to the dispatcher.
*/
void NDArray::RecordStream(DGLArray* tensor, DGLStreamHandle stream) {
TensorDispatcher* td = TensorDispatcher::Global();
CHECK(td->IsAvailable()) << "RecordStream only works when TensorAdaptor is available.";
CHECK(td->IsAvailable())
<< "RecordStream only works when TensorAdaptor is available.";
CHECK_EQ(tensor->ctx.device_type, kDGLCUDA)
<< "RecordStream only works with GPU tensors.";
td->RecordStream(tensor->data, stream, tensor->ctx.device_id);
}
/**
* @brief Build a 1-D NDArray on `ctx` holding a copy of `vec`.
*
* The element dtype comes from DGLDataTypeTraits<T>. A new array of
* vec.size() elements is allocated with NDArray::Empty, then
* size * sizeof(T) bytes are copied from the vector (source context
* {kDGLCPU, 0}) through the DeviceAPI of `ctx`.
*
* @param vec Source values.
* @param ctx Context of the resulting array.
* @return The newly created array.
*/
template<typename T>
template <typename T>
NDArray NDArray::FromVector(const std::vector<T>& vec, DGLContext ctx) {
const DGLDataType dtype = DGLDataTypeTraits<T>::dtype;
int64_t size = static_cast<int64_t>(vec.size());
NDArray ret = NDArray::Empty({size}, dtype, ctx);
DeviceAPI::Get(ctx)->CopyDataFromTo(
vec.data(),
0,
static_cast<T*>(ret->data),
0,
size * sizeof(T),
DGLContext{kDGLCPU, 0},
ctx,
dtype);
vec.data(), 0, static_cast<T*>(ret->data), 0, size * sizeof(T),
DGLContext{kDGLCPU, 0}, ctx, dtype);
return ret;
}
/**
* @brief Wrap an existing buffer in an NDArray without copying it.
*
* The container is built by Internal::Create and its data pointer is set to
* `raw`. When `auto_free` is false the container's deleter is cleared.
* NOTE(review): presumably this leaves ownership of `raw` with the caller;
* confirm against how a null deleter is handled on release.
*
* @param shape Shape of the array.
* @param dtype Element data type.
* @param ctx Context recorded on the array.
* @param raw Pointer to the existing data.
* @param auto_free If false, the deleter is set to nullptr.
* @return The array wrapping `raw`.
*/
NDArray NDArray::CreateFromRaw(const std::vector<int64_t>& shape,
DGLDataType dtype, DGLContext ctx, void* raw, bool auto_free) {
NDArray NDArray::CreateFromRaw(
const std::vector<int64_t>& shape, DGLDataType dtype, DGLContext ctx,
void* raw, bool auto_free) {
NDArray ret = Internal::Create(shape, dtype, ctx);
ret.data_->dl_tensor.data = raw;
if (!auto_free)
ret.data_->deleter = nullptr;
if (!auto_free) ret.data_->deleter = nullptr;
return ret;
}
// export specializations
template NDArray NDArray::FromVector<int32_t>(const std::vector<int32_t>&, DGLContext);
template NDArray NDArray::FromVector<int64_t>(const std::vector<int64_t>&, DGLContext);
template NDArray NDArray::FromVector<uint32_t>(const std::vector<uint32_t>&, DGLContext);
template NDArray NDArray::FromVector<uint64_t>(const std::vector<uint64_t>&, DGLContext);
template NDArray NDArray::FromVector<float>(const std::vector<float>&, DGLContext);
template NDArray NDArray::FromVector<double>(const std::vector<double>&, DGLContext);
template<typename T>
template NDArray NDArray::FromVector<int32_t>(
const std::vector<int32_t>&, DGLContext);
template NDArray NDArray::FromVector<int64_t>(
const std::vector<int64_t>&, DGLContext);
template NDArray NDArray::FromVector<uint32_t>(
const std::vector<uint32_t>&, DGLContext);
template NDArray NDArray::FromVector<uint64_t>(
const std::vector<uint64_t>&, DGLContext);
template NDArray NDArray::FromVector<float>(
const std::vector<float>&, DGLContext);
template NDArray NDArray::FromVector<double>(
const std::vector<double>&, DGLContext);
/**
* @brief Copy the contents of a 1-D array into a std::vector<T>.
*
* Fails a CHECK if the array is not one-dimensional or if its dtype differs
* from DGLDataTypeTraits<T>::dtype. The copy of shape[0] * sizeof(T) bytes
* goes through the DeviceAPI of the array's context, with {kDGLCPU, 0} as
* the destination context.
*
* NOTE(review): the source offset passed to CopyDataFromTo is 0, so
* dl_tensor.byte_offset is not applied here, whereas NDArray::CopyFromTo
* does pass byte_offset. Confirm this is intended for arrays with a
* non-zero offset.
*
* @return A vector holding the array's elements.
*/
template <typename T>
std::vector<T> NDArray::ToVector() const {
const DGLDataType dtype = DGLDataTypeTraits<T>::dtype;
CHECK(data_->dl_tensor.ndim == 1) << "ToVector() only supported for 1D arrays";
CHECK(data_->dl_tensor.ndim == 1)
<< "ToVector() only supported for 1D arrays";
CHECK(data_->dl_tensor.dtype == dtype) << "dtype mismatch";
int64_t size = data_->dl_tensor.shape[0];
std::vector<T> vec(size);
const DGLContext &ctx = data_->dl_tensor.ctx;
const DGLContext& ctx = data_->dl_tensor.ctx;
DeviceAPI::Get(ctx)->CopyDataFromTo(
static_cast<T*>(data_->dl_tensor.data),
0,
vec.data(),
0,
size * sizeof(T),
ctx,
DGLContext{kDGLCPU, 0},
dtype);
static_cast<T*>(data_->dl_tensor.data), 0, vec.data(), 0,
size * sizeof(T), ctx, DGLContext{kDGLCPU, 0}, dtype);
return vec;
}
......@@ -313,13 +300,12 @@ std::shared_ptr<SharedMemory> NDArray::GetSharedMem() const {
}
/**
* @brief Tell whether the container's data is in pinned memory.
*
* Returns true immediately when the container is flagged as pinned by DGL.
* Otherwise a non-CPU tensor is reported as not pinned, and a CPU tensor is
* reported as pinned only if a CUDA DeviceAPI is obtained and its IsPinned()
* check on the data pointer succeeds.
*
* @param ptr Container to inspect.
* @return true if the data is considered pinned, false otherwise.
*/
bool NDArray::IsContainerPinned(NDArray::Container* ptr) {
if (ptr->pinned_by_dgl_)
return true;
if (ptr->pinned_by_dgl_) return true;
auto* tensor = &(ptr->dl_tensor);
// Can only be pinned if on CPU...
if (tensor->ctx.device_type != kDGLCPU)
return false;
// ... and CUDA device API is enabled, and the tensor is indeed in pinned memory.
if (tensor->ctx.device_type != kDGLCPU) return false;
// ... and CUDA device API is enabled, and the tensor is indeed in pinned
// memory.
// NOTE(review): the `true` argument presumably allows a missing CUDA API to
// yield a null device instead of failing; the null check below relies on it.
auto device = DeviceAPI::Get(kDGLCUDA, true);
return device && device->IsPinned(tensor->data);
}
......@@ -340,27 +326,20 @@ bool NDArray::Load(dmlc::Stream* strm) {
return true;
}
uint64_t header, reserved;
CHECK(strm->Read(&header))
<< "Invalid DGLArray file format";
CHECK(strm->Read(&reserved))
<< "Invalid DGLArray file format";
CHECK(header == kDGLNDArrayMagic)
<< "Invalid DGLArray file format";
CHECK(strm->Read(&header)) << "Invalid DGLArray file format";
CHECK(strm->Read(&reserved)) << "Invalid DGLArray file format";
CHECK(header == kDGLNDArrayMagic) << "Invalid DGLArray file format";
DGLContext ctx;
int ndim;
DGLDataType dtype;
CHECK(strm->Read(&ctx))
<< "Invalid DGLArray file format";
CHECK(strm->Read(&ndim))
<< "Invalid DGLArray file format";
CHECK(strm->Read(&dtype))
<< "Invalid DGLArray file format";
CHECK(strm->Read(&ctx)) << "Invalid DGLArray file format";
CHECK(strm->Read(&ndim)) << "Invalid DGLArray file format";
CHECK(strm->Read(&dtype)) << "Invalid DGLArray file format";
CHECK_EQ(ctx.device_type, kDGLCPU)
<< "Invalid DGLArray context: can only save as CPU tensor";
std::vector<int64_t> shape(ndim);
if (ndim != 0) {
CHECK(strm->ReadArray(&shape[0], ndim))
<< "Invalid DGLArray file format";
CHECK(strm->ReadArray(&shape[0], ndim)) << "Invalid DGLArray file format";
}
NDArray ret = NDArray::Empty(shape, dtype, ctx);
int64_t num_elems = 1;
......@@ -369,8 +348,7 @@ bool NDArray::Load(dmlc::Stream* strm) {
num_elems *= ret->shape[i];
}
int64_t data_byte_size;
CHECK(strm->Read(&data_byte_size))
<< "Invalid DGLArray file format";
CHECK(strm->Read(&data_byte_size)) << "Invalid DGLArray file format";
CHECK(data_byte_size == num_elems * elem_bytes)
<< "Invalid DGLArray file format";
if (data_byte_size != 0) {
......@@ -386,20 +364,14 @@ bool NDArray::Load(dmlc::Stream* strm) {
return true;
}
} // namespace runtime
} // namespace dgl
using namespace dgl::runtime;
int DGLArrayAlloc(const dgl_index_t* shape,
int ndim,
int dtype_code,
int dtype_bits,
int dtype_lanes,
int device_type,
int device_id,
DGLArrayHandle* out) {
int DGLArrayAlloc(
const dgl_index_t* shape, int ndim, int dtype_code, int dtype_bits,
int dtype_lanes, int device_type, int device_id, DGLArrayHandle* out) {
API_BEGIN();
DGLDataType dtype;
dtype.code = static_cast<uint8_t>(dtype_code);
......@@ -413,22 +385,17 @@ int DGLArrayAlloc(const dgl_index_t* shape,
API_END();
}
/**
* @brief C API: create an array through NDArray::EmptyShared and hand it
* out as a DGLArrayHandle.
*
* The dtype is assembled from the code/bits/lanes integers, the shape is
* copied from `shape[0..ndim)`, and the array is always created with
* context {kDGLCPU, 0}.
*
* @param mem_name Name passed to NDArray::EmptyShared.
* @param shape Pointer to `ndim` extents.
* @param ndim Number of dimensions.
* @param dtype_code Type code, narrowed to uint8_t.
* @param dtype_bits Bit width, narrowed to uint8_t.
* @param dtype_lanes Lane count, narrowed to uint16_t.
* @param is_create Forwarded to NDArray::EmptyShared.
* @param out Receives the resulting handle.
* @return Status code produced by the API_BEGIN/API_END wrapper.
*/
int DGLArrayAllocSharedMem(const char *mem_name,
const dgl_index_t *shape,
int ndim,
int dtype_code,
int dtype_bits,
int dtype_lanes,
bool is_create,
DGLArrayHandle* out) {
int DGLArrayAllocSharedMem(
const char* mem_name, const dgl_index_t* shape, int ndim, int dtype_code,
int dtype_bits, int dtype_lanes, bool is_create, DGLArrayHandle* out) {
API_BEGIN();
DGLDataType dtype;
dtype.code = static_cast<uint8_t>(dtype_code);
dtype.bits = static_cast<uint8_t>(dtype_bits);
dtype.lanes = static_cast<uint16_t>(dtype_lanes);
std::vector<int64_t> shape_vec(shape, shape + ndim);
NDArray arr = NDArray::EmptyShared(mem_name, shape_vec, dtype,
DGLContext{kDGLCPU, 0}, is_create);
NDArray arr = NDArray::EmptyShared(
mem_name, shape_vec, dtype, DGLContext{kDGLCPU, 0}, is_create);
*out = NDArray::Internal::MoveAsDGLArray(arr);
API_END();
}
......@@ -439,57 +406,48 @@ int DGLArrayFree(DGLArrayHandle handle) {
API_END();
}
/**
 * @brief C API entry point that forwards both handles to
 *        NDArray::CopyFromTo(from, to).
 *
 * NOTE(review): this span is diff residue. The signature appears twice
 * (before and after clang-format) with the same tokens.
 *
 * @param from Handle passed as the first argument of NDArray::CopyFromTo.
 * @param to Handle passed as the second argument of NDArray::CopyFromTo.
 * @return Whatever the API_BEGIN()/API_END() macro pair yields; their
 *         definitions are not visible here -- presumably a status code,
 *         TODO confirm.
 */
int DGLArrayCopyFromTo(DGLArrayHandle from,
DGLArrayHandle to) {
int DGLArrayCopyFromTo(DGLArrayHandle from, DGLArrayHandle to) {
API_BEGIN();
NDArray::CopyFromTo(from, to);
API_END();
}
/**
 * @brief C API entry point: copy `nbytes` bytes from the raw buffer `data`
 *        into the array behind `handle`.
 *
 * NOTE(review): this span is diff residue. The signature, the CHECK_EQ and
 * the CopyDataFromTo call each appear in their pre- and post-clang-format
 * renderings; the trailing argument line is shared by both.
 *
 * @param handle Array handle; its `data`, `byte_offset`, `ctx` and `dtype`
 *        fields are read.
 * @param data Raw buffer, passed with offset 0 and a CPU context
 *        (device_type = kDGLCPU, device_id = 0).
 * @param nbytes Byte count; CHECK_EQ requires it to equal
 *        GetDataSize(*handle).
 * @return Whatever the API_BEGIN()/API_END() macro pair yields; their
 *         definitions are not visible here -- presumably a status code,
 *         TODO confirm.
 */
int DGLArrayCopyFromBytes(DGLArrayHandle handle,
void* data,
size_t nbytes) {
int DGLArrayCopyFromBytes(DGLArrayHandle handle, void* data, size_t nbytes) {
API_BEGIN();
// Context describing the raw buffer side of the copy.
DGLContext cpu_ctx;
cpu_ctx.device_type = kDGLCPU;
cpu_ctx.device_id = 0;
// Partial copies are rejected: the byte count must match the array exactly.
size_t arr_size = GetDataSize(*handle);
CHECK_EQ(arr_size, nbytes)
<< "DGLArrayCopyFromBytes: size mismatch";
// The device API is looked up through the array's own context.
// NOTE(review): argument order is presumably (src, src_offset, dst,
// dst_offset, size, src_ctx, dst_ctx, dtype) -- confirm against DeviceAPI.
DeviceAPI::Get(handle->ctx)->CopyDataFromTo(
data, 0,
handle->data, static_cast<size_t>(handle->byte_offset),
CHECK_EQ(arr_size, nbytes) << "DGLArrayCopyFromBytes: size mismatch";
DeviceAPI::Get(handle->ctx)
->CopyDataFromTo(
data, 0, handle->data, static_cast<size_t>(handle->byte_offset),
nbytes, cpu_ctx, handle->ctx, handle->dtype);
API_END();
}
/**
 * @brief C API entry point: copy `nbytes` bytes out of the array behind
 *        `handle` into the raw buffer `data`.
 *
 * NOTE(review): this span is diff residue. The signature, the CHECK_EQ and
 * the CopyDataFromTo call each appear in their pre- and post-clang-format
 * renderings; the trailing argument line is shared by both.
 *
 * @param handle Array handle; its `data`, `byte_offset`, `ctx` and `dtype`
 *        fields are read.
 * @param data Raw buffer, passed with offset 0 and a CPU context
 *        (device_type = kDGLCPU, device_id = 0).
 * @param nbytes Byte count; CHECK_EQ requires it to equal
 *        GetDataSize(*handle).
 * @return Whatever the API_BEGIN()/API_END() macro pair yields; their
 *         definitions are not visible here -- presumably a status code,
 *         TODO confirm.
 */
int DGLArrayCopyToBytes(DGLArrayHandle handle,
void* data,
size_t nbytes) {
int DGLArrayCopyToBytes(DGLArrayHandle handle, void* data, size_t nbytes) {
API_BEGIN();
// Context describing the raw buffer side of the copy.
DGLContext cpu_ctx;
cpu_ctx.device_type = kDGLCPU;
cpu_ctx.device_id = 0;
// Partial copies are rejected: the byte count must match the array exactly.
size_t arr_size = GetDataSize(*handle);
CHECK_EQ(arr_size, nbytes)
<< "DGLArrayCopyToBytes: size mismatch";
// Mirror of DGLArrayCopyFromBytes: the array's storage and context come
// first here, the raw buffer and cpu_ctx second.
DeviceAPI::Get(handle->ctx)->CopyDataFromTo(
handle->data, static_cast<size_t>(handle->byte_offset),
data, 0,
CHECK_EQ(arr_size, nbytes) << "DGLArrayCopyToBytes: size mismatch";
DeviceAPI::Get(handle->ctx)
->CopyDataFromTo(
handle->data, static_cast<size_t>(handle->byte_offset), data, 0,
nbytes, handle->ctx, cpu_ctx, handle->dtype);
API_END();
}
/**
 * @brief C API entry point: reinterpret `handle` as an NDArray::Container
 *        and pass it to NDArray::PinContainer.
 *
 * NOTE(review): this span is diff residue. The signature appears twice
 * (before and after clang-format) with the same tokens.
 *
 * @param handle Array handle, reinterpret_cast to NDArray::Container*.
 * @param ctx Not referenced anywhere in the body.
 * @return Whatever the API_BEGIN()/API_END() macro pair yields; their
 *         definitions are not visible here -- presumably a status code,
 *         TODO confirm.
 */
int DGLArrayPinData(DGLArrayHandle handle,
DGLContext ctx) {
int DGLArrayPinData(DGLArrayHandle handle, DGLContext ctx) {
API_BEGIN();
// NOTE(review): the cast assumes every handle passed here really points at
// an NDArray::Container -- confirm against how handles are created.
auto* nd_container = reinterpret_cast<NDArray::Container*>(handle);
NDArray::PinContainer(nd_container);
API_END();
}
int DGLArrayUnpinData(DGLArrayHandle handle,
DGLContext ctx) {
int DGLArrayUnpinData(DGLArrayHandle handle, DGLContext ctx) {
API_BEGIN();
auto* nd_container = reinterpret_cast<NDArray::Container*>(handle);
NDArray::UnpinContainer(nd_container);
......
#include <gtest/gtest.h>
#include <dgl/array.h>
#include <tuple>
#include <gtest/gtest.h>
#include <set>
#include <tuple>
#include "./common.h"
using namespace dgl;
......@@ -62,7 +64,7 @@ std::set<ETuple<Idx>> ToEdgeSet(COOMatrix mat) {
Idx* col = static_cast<Idx*>(mat.col->data);
Idx* data = static_cast<Idx*>(mat.data->data);
for (int64_t i = 0; i < mat.row->shape[0]; ++i) {
//std::cout << row[i] << " " << col[i] << " " << data[i] << std::endl;
// std::cout << row[i] << " " << col[i] << " " << data[i] << std::endl;
eset.emplace(row[i], col[i], data[i]);
}
return eset;
......@@ -122,7 +124,8 @@ COOMatrix COO(bool has_data) {
template <typename Idx>
std::pair<CSRMatrix, std::vector<int64_t>> CSREtypes(bool has_data) {
IdArray indptr = NDArray::FromVector(std::vector<Idx>({0, 4, 5, 5, 7}));
IdArray indices = NDArray::FromVector(std::vector<Idx>({0, 1, 2, 3, 1, 3, 2}));
IdArray indices =
NDArray::FromVector(std::vector<Idx>({0, 1, 2, 3, 1, 3, 2}));
IdArray data = NDArray::FromVector(std::vector<Idx>({0, 1, 4, 6, 2, 3, 5}));
auto eid2etype_offsets = std::vector<int64_t>({0, 4, 5, 6, 7});
if (has_data)
......@@ -146,8 +149,8 @@ std::pair<COOMatrix, std::vector<int64_t>> COOEtypes(bool has_data) {
template <typename Idx, typename FloatType>
void _TestCSRSampling(bool has_data) {
auto mat = CSR<Idx>(has_data);
FloatArray prob = NDArray::FromVector(
std::vector<FloatType>({.5, .5, .5, .5, .5}));
FloatArray prob =
NDArray::FromVector(std::vector<FloatType>({.5, .5, .5, .5, .5}));
IdArray rows = NDArray::FromVector(std::vector<Idx>({0, 3}));
for (int k = 0; k < 10; ++k) {
auto rst = CSRRowWiseSampling(mat, rows, 2, prob, true);
......@@ -170,8 +173,7 @@ void _TestCSRSampling(bool has_data) {
ASSERT_TRUE(eset.count(std::make_tuple(3, 3, 4)));
}
}
prob = NDArray::FromVector(
std::vector<FloatType>({.0, .5, .5, .0, .5}));
prob = NDArray::FromVector(std::vector<FloatType>({.0, .5, .5, .0, .5}));
for (int k = 0; k < 100; ++k) {
auto rst = CSRRowWiseSampling(mat, rows, 2, prob, true);
CheckSampledResult<Idx>(rst, rows, has_data);
......@@ -243,15 +245,16 @@ void _TestCSRPerEtypeSampling(bool has_data) {
NDArray::FromVector(std::vector<FloatType>({.5, .5, .5, .5})),
NDArray::FromVector(std::vector<FloatType>({.5})),
NDArray::FromVector(std::vector<FloatType>({.5})),
NDArray::FromVector(std::vector<FloatType>({.5}))
};
NDArray::FromVector(std::vector<FloatType>({.5}))};
IdArray rows = NDArray::FromVector(std::vector<Idx>({0, 3}));
for (int k = 0; k < 10; ++k) {
auto rst = CSRRowWisePerEtypeSampling(mat, rows, eid2etype_offset, {2, 2, 2, 2}, prob, true);
auto rst = CSRRowWisePerEtypeSampling(
mat, rows, eid2etype_offset, {2, 2, 2, 2}, prob, true);
CheckSampledPerEtypeResult<Idx>(rst, rows, has_data);
}
for (int k = 0; k < 10; ++k) {
auto rst = CSRRowWisePerEtypeSampling(mat, rows, eid2etype_offset, {2, 2, 2, 2}, prob, false);
auto rst = CSRRowWisePerEtypeSampling(
mat, rows, eid2etype_offset, {2, 2, 2, 2}, prob, false);
CheckSampledPerEtypeResult<Idx>(rst, rows, has_data);
auto eset = ToEdgeSet<Idx>(rst);
if (has_data) {
......@@ -297,10 +300,10 @@ void _TestCSRPerEtypeSampling(bool has_data) {
NDArray::FromVector(std::vector<FloatType>({.0, .5, .0, .0})),
NDArray::FromVector(std::vector<FloatType>({.5})),
NDArray::FromVector(std::vector<FloatType>({.5})),
NDArray::FromVector(std::vector<FloatType>({.5}))
};
NDArray::FromVector(std::vector<FloatType>({.5}))};
for (int k = 0; k < 10; ++k) {
auto rst = CSRRowWisePerEtypeSampling(mat, rows, eid2etype_offset, {2, 2, 2, 2}, prob, true);
auto rst = CSRRowWisePerEtypeSampling(
mat, rows, eid2etype_offset, {2, 2, 2, 2}, prob, true);
CheckSampledPerEtypeResult<Idx>(rst, rows, has_data);
auto eset = ToEdgeSet<Idx>(rst);
if (has_data) {
......@@ -322,15 +325,16 @@ void _TestCSRPerEtypeSamplingSorted() {
NDArray::FromVector(std::vector<FloatType>({.5, .5, .5, .5})),
NDArray::FromVector(std::vector<FloatType>({.5})),
NDArray::FromVector(std::vector<FloatType>({.5})),
NDArray::FromVector(std::vector<FloatType>({.5}))
};
NDArray::FromVector(std::vector<FloatType>({.5}))};
IdArray rows = NDArray::FromVector(std::vector<Idx>({0, 3}));
for (int k = 0; k < 10; ++k) {
auto rst = CSRRowWisePerEtypeSampling(mat, rows, eid2etype_offset, {2, 2, 2, 2}, prob, true, true);
auto rst = CSRRowWisePerEtypeSampling(
mat, rows, eid2etype_offset, {2, 2, 2, 2}, prob, true, true);
CheckSampledPerEtypeResult<Idx>(rst, rows, true);
}
for (int k = 0; k < 10; ++k) {
auto rst = CSRRowWisePerEtypeSampling(mat, rows, eid2etype_offset, {2, 2, 2, 2}, prob, false, true);
auto rst = CSRRowWisePerEtypeSampling(
mat, rows, eid2etype_offset, {2, 2, 2, 2}, prob, false, true);
CheckSampledPerEtypeResult<Idx>(rst, rows, true);
auto eset = ToEdgeSet<Idx>(rst);
int counts = 0;
......@@ -358,10 +362,10 @@ void _TestCSRPerEtypeSamplingSorted() {
NDArray::FromVector(std::vector<FloatType>({.0, .5, .0, .0})),
NDArray::FromVector(std::vector<FloatType>({.5})),
NDArray::FromVector(std::vector<FloatType>({.5})),
NDArray::FromVector(std::vector<FloatType>({.5}))
};
NDArray::FromVector(std::vector<FloatType>({.5}))};
for (int k = 0; k < 10; ++k) {
auto rst = CSRRowWisePerEtypeSampling(mat, rows, eid2etype_offset, {2, 2, 2, 2}, prob, true, true);
auto rst = CSRRowWisePerEtypeSampling(
mat, rows, eid2etype_offset, {2, 2, 2, 2}, prob, true, true);
CheckSampledPerEtypeResult<Idx>(rst, rows, true);
auto eset = ToEdgeSet<Idx>(rst);
ASSERT_FALSE(eset.count(std::make_tuple(0, 0, 0)));
......@@ -389,18 +393,17 @@ void _TestCSRPerEtypeSamplingUniform(bool has_data) {
auto mat = pair.first;
auto eid2etype_offset = pair.second;
std::vector<FloatArray> prob = {
aten::NullArray(),
aten::NullArray(),
aten::NullArray(),
aten::NullArray()
};
aten::NullArray(), aten::NullArray(), aten::NullArray(),
aten::NullArray()};
IdArray rows = NDArray::FromVector(std::vector<Idx>({0, 3}));
for (int k = 0; k < 10; ++k) {
auto rst = CSRRowWisePerEtypeSampling(mat, rows, eid2etype_offset, {2, 2, 2, 2}, prob, true);
auto rst = CSRRowWisePerEtypeSampling(
mat, rows, eid2etype_offset, {2, 2, 2, 2}, prob, true);
CheckSampledPerEtypeResult<Idx>(rst, rows, has_data);
}
for (int k = 0; k < 10; ++k) {
auto rst = CSRRowWisePerEtypeSampling(mat, rows, eid2etype_offset, {2, 2, 2, 2}, prob, false);
auto rst = CSRRowWisePerEtypeSampling(
mat, rows, eid2etype_offset, {2, 2, 2, 2}, prob, false);
CheckSampledPerEtypeResult<Idx>(rst, rows, has_data);
auto eset = ToEdgeSet<Idx>(rst);
if (has_data) {
......@@ -449,18 +452,17 @@ void _TestCSRPerEtypeSamplingUniformSorted() {
auto mat = pair.first;
auto eid2etype_offset = pair.second;
std::vector<FloatArray> prob = {
aten::NullArray(),
aten::NullArray(),
aten::NullArray(),
aten::NullArray()
};
aten::NullArray(), aten::NullArray(), aten::NullArray(),
aten::NullArray()};
IdArray rows = NDArray::FromVector(std::vector<Idx>({0, 3}));
for (int k = 0; k < 10; ++k) {
auto rst = CSRRowWisePerEtypeSampling(mat, rows, eid2etype_offset, {2, 2, 2, 2}, prob, true, true);
auto rst = CSRRowWisePerEtypeSampling(
mat, rows, eid2etype_offset, {2, 2, 2, 2}, prob, true, true);
CheckSampledPerEtypeResult<Idx>(rst, rows, true);
}
for (int k = 0; k < 10; ++k) {
auto rst = CSRRowWisePerEtypeSampling(mat, rows, eid2etype_offset, {2, 2, 2, 2}, prob, false, true);
auto rst = CSRRowWisePerEtypeSampling(
mat, rows, eid2etype_offset, {2, 2, 2, 2}, prob, false, true);
CheckSampledPerEtypeResult<Idx>(rst, rows, true);
auto eset = ToEdgeSet<Idx>(rst);
int counts = 0;
......@@ -500,12 +502,11 @@ TEST(RowwiseTest, TestCSRPerEtypeSamplingUniform) {
_TestCSRPerEtypeSamplingUniformSorted<int64_t, double>();
}
template <typename Idx, typename FloatType>
void _TestCOOSampling(bool has_data) {
auto mat = COO<Idx>(has_data);
FloatArray prob = NDArray::FromVector(
std::vector<FloatType>({.5, .5, .5, .5, .5}));
FloatArray prob =
NDArray::FromVector(std::vector<FloatType>({.5, .5, .5, .5, .5}));
IdArray rows = NDArray::FromVector(std::vector<Idx>({0, 3}));
for (int k = 0; k < 10; ++k) {
auto rst = COORowWiseSampling(mat, rows, 2, prob, true);
......@@ -528,8 +529,7 @@ void _TestCOOSampling(bool has_data) {
ASSERT_TRUE(eset.count(std::make_tuple(3, 3, 4)));
}
}
prob = NDArray::FromVector(
std::vector<FloatType>({.0, .5, .5, .0, .5}));
prob = NDArray::FromVector(std::vector<FloatType>({.0, .5, .5, .0, .5}));
for (int k = 0; k < 100; ++k) {
auto rst = COORowWiseSampling(mat, rows, 2, prob, true);
CheckSampledResult<Idx>(rst, rows, has_data);
......@@ -604,15 +604,16 @@ void _TestCOOPerEtypeSampling(bool has_data) {
NDArray::FromVector(std::vector<FloatType>({.5, .5, .5, .5})),
NDArray::FromVector(std::vector<FloatType>({.5})),
NDArray::FromVector(std::vector<FloatType>({.5})),
NDArray::FromVector(std::vector<FloatType>({.5}))
};
NDArray::FromVector(std::vector<FloatType>({.5}))};
IdArray rows = NDArray::FromVector(std::vector<Idx>({0, 3}));
for (int k = 0; k < 10; ++k) {
auto rst = COORowWisePerEtypeSampling(mat, rows, eid2etype_offset, {2, 2, 2, 2}, prob, true);
auto rst = COORowWisePerEtypeSampling(
mat, rows, eid2etype_offset, {2, 2, 2, 2}, prob, true);
CheckSampledPerEtypeResult<Idx>(rst, rows, has_data);
}
for (int k = 0; k < 10; ++k) {
auto rst = COORowWisePerEtypeSampling(mat, rows, eid2etype_offset, {2, 2, 2, 2}, prob, false);
auto rst = COORowWisePerEtypeSampling(
mat, rows, eid2etype_offset, {2, 2, 2, 2}, prob, false);
CheckSampledPerEtypeResult<Idx>(rst, rows, has_data);
auto eset = ToEdgeSet<Idx>(rst);
if (has_data) {
......@@ -658,10 +659,10 @@ void _TestCOOPerEtypeSampling(bool has_data) {
NDArray::FromVector(std::vector<FloatType>({.0, .5, .0, .0})),
NDArray::FromVector(std::vector<FloatType>({.5})),
NDArray::FromVector(std::vector<FloatType>({.5})),
NDArray::FromVector(std::vector<FloatType>({.5}))
};
NDArray::FromVector(std::vector<FloatType>({.5}))};
for (int k = 0; k < 10; ++k) {
auto rst = COORowWisePerEtypeSampling(mat, rows, eid2etype_offset, {2, 2, 2, 2}, prob, true);
auto rst = COORowWisePerEtypeSampling(
mat, rows, eid2etype_offset, {2, 2, 2, 2}, prob, true);
CheckSampledPerEtypeResult<Idx>(rst, rows, has_data);
auto eset = ToEdgeSet<Idx>(rst);
if (has_data) {
......@@ -691,18 +692,17 @@ void _TestCOOPerEtypeSamplingUniform(bool has_data) {
auto mat = pair.first;
auto eid2etype_offset = pair.second;
std::vector<FloatArray> prob = {
aten::NullArray(),
aten::NullArray(),
aten::NullArray(),
aten::NullArray()
};
aten::NullArray(), aten::NullArray(), aten::NullArray(),
aten::NullArray()};
IdArray rows = NDArray::FromVector(std::vector<Idx>({0, 3}));
for (int k = 0; k < 10; ++k) {
auto rst = COORowWisePerEtypeSampling(mat, rows, eid2etype_offset, {2, 2, 2, 2}, prob, true);
auto rst = COORowWisePerEtypeSampling(
mat, rows, eid2etype_offset, {2, 2, 2, 2}, prob, true);
CheckSampledPerEtypeResult<Idx>(rst, rows, has_data);
}
for (int k = 0; k < 10; ++k) {
auto rst = COORowWisePerEtypeSampling(mat, rows, eid2etype_offset, {2, 2, 2, 2}, prob, false);
auto rst = COORowWisePerEtypeSampling(
mat, rows, eid2etype_offset, {2, 2, 2, 2}, prob, false);
CheckSampledPerEtypeResult<Idx>(rst, rows, has_data);
auto eset = ToEdgeSet<Idx>(rst);
if (has_data) {
......@@ -759,8 +759,8 @@ TEST(RowwiseTest, TestCOOPerEtypeSamplingUniform) {
template <typename Idx, typename FloatType>
void _TestCSRTopk(bool has_data) {
auto mat = CSR<Idx>(has_data);
FloatArray weight = NDArray::FromVector(
std::vector<FloatType>({.1f, .0f, -.1f, .2f, .5f}));
FloatArray weight =
NDArray::FromVector(std::vector<FloatType>({.1f, .0f, -.1f, .2f, .5f}));
// -.1, .2, .1, .0, .5
IdArray rows = NDArray::FromVector(std::vector<Idx>({0, 3}));
......@@ -802,12 +802,11 @@ TEST(RowwiseTest, TestCSRTopk) {
_TestCSRTopk<int64_t, double>(false);
}
template <typename Idx, typename FloatType>
void _TestCOOTopk(bool has_data) {
auto mat = COO<Idx>(has_data);
FloatArray weight = NDArray::FromVector(
std::vector<FloatType>({.1f, .0f, -.1f, .2f, .5f}));
FloatArray weight =
NDArray::FromVector(std::vector<FloatType>({.1f, .0f, -.1f, .2f, .5f}));
// -.1, .2, .1, .0, .5
IdArray rows = NDArray::FromVector(std::vector<Idx>({0, 3}));
......@@ -856,16 +855,11 @@ void _TestCSRSamplingBiased(bool has_data) {
// 1 - 1
// 3 - 2,3
NDArray tag_offset = NDArray::FromVector(
std::vector<Idx>({0, 1, 2,
0, 0, 1,
0, 0, 0,
0, 1, 2}));
std::vector<Idx>({0, 1, 2, 0, 0, 1, 0, 0, 0, 0, 1, 2}));
tag_offset = tag_offset.CreateView({4, 3}, tag_offset->dtype);
IdArray rows = NDArray::FromVector(std::vector<Idx>({0, 1, 3}));
FloatArray bias = NDArray::FromVector(
std::vector<FloatType>({0, 0.5})
);
for (int k = 0 ; k < 10 ; ++k) {
FloatArray bias = NDArray::FromVector(std::vector<FloatType>({0, 0.5}));
for (int k = 0; k < 10; ++k) {
auto rst = CSRRowWiseSamplingBiased(mat, rows, 1, tag_offset, bias, false);
CheckSampledResult<Idx>(rst, rows, has_data);
auto eset = ToEdgeSet<Idx>(rst);
......@@ -879,7 +873,7 @@ void _TestCSRSamplingBiased(bool has_data) {
ASSERT_TRUE(eset.count(std::make_tuple(3, 3, 4)));
}
}
for (int k = 0 ; k < 10 ; ++k) {
for (int k = 0; k < 10; ++k) {
auto rst = CSRRowWiseSamplingBiased(mat, rows, 3, tag_offset, bias, true);
CheckSampledResult<Idx>(rst, rows, has_data);
auto eset = ToEdgeSet<Idx>(rst);
......
#include <gtest/gtest.h>
#include <vector>
#include <algorithm>
#include <iostream>
#include "./common.h"
#include <vector>
#include "../../src/random/cpu/sample_utils.h"
#include "./common.h"
using namespace dgl;
using namespace dgl::aten;
......@@ -11,7 +13,7 @@ using namespace dgl::aten;
// TODO: adapt this to Random::Choice
template <typename Idx, typename DType>
void _TestWithReplacement(RandomEngine *re) {
void _TestWithReplacement(RandomEngine* re) {
Idx n_categories = 100;
Idx n_rolls = 1000000;
std::vector<DType> _prob;
......@@ -20,12 +22,11 @@ void _TestWithReplacement(RandomEngine *re) {
_prob.push_back(re->Uniform<DType>());
accum += _prob.back();
}
for (Idx i = 0; i < n_categories; ++i)
_prob[i] /= accum;
for (Idx i = 0; i < n_categories; ++i) _prob[i] /= accum;
FloatArray prob = NDArray::FromVector(_prob);
auto _check_given_sampler = [n_categories, n_rolls, &_prob](
utils::BaseSampler<Idx> *s) {
auto _check_given_sampler = [n_categories, n_rolls,
&_prob](utils::BaseSampler<Idx>* s) {
std::vector<Idx> counter(n_categories, 0);
for (Idx i = 0; i < n_rolls; ++i) {
Idx dice = s->Draw();
......@@ -67,14 +68,13 @@ TEST(SampleUtilsTest, TestWithReplacement) {
};
template <typename Idx, typename DType>
void _TestWithoutReplacementOrder(RandomEngine *re) {
void _TestWithoutReplacementOrder(RandomEngine* re) {
// TODO(BarclayII): is there a reliable way to do this test?
std::vector<DType> _prob = {1e6f, 1e-6f, 1e-2f, 1e2f};
FloatArray prob = NDArray::FromVector(_prob);
std::vector<Idx> ground_truth = {0, 3, 2, 1};
auto _check_given_sampler = [&ground_truth](
utils::BaseSampler<Idx> *s) {
auto _check_given_sampler = [&ground_truth](utils::BaseSampler<Idx>* s) {
for (size_t i = 0; i < ground_truth.size(); ++i) {
Idx dice = s->Draw();
ASSERT_EQ(dice, ground_truth[i]);
......@@ -102,22 +102,19 @@ TEST(SampleUtilsTest, TestWithoutReplacementOrder) {
};
template <typename Idx, typename DType>
void _TestWithoutReplacementUnique(RandomEngine *re) {
void _TestWithoutReplacementUnique(RandomEngine* re) {
Idx N = 1000000;
std::vector<DType> _likelihood;
for (Idx i = 0; i < N; ++i)
_likelihood.push_back(re->Uniform<DType>());
for (Idx i = 0; i < N; ++i) _likelihood.push_back(re->Uniform<DType>());
FloatArray likelihood = NDArray::FromVector(_likelihood);
auto _check_given_sampler = [N](
utils::BaseSampler<Idx> *s) {
auto _check_given_sampler = [N](utils::BaseSampler<Idx>* s) {
std::vector<int> cnt(N, 0);
for (Idx i = 0; i < N; ++i) {
Idx dice = s->Draw();
cnt[dice]++;
}
for (Idx i = 0; i < N; ++i)
ASSERT_EQ(cnt[i], 1);
for (Idx i = 0; i < N; ++i) ASSERT_EQ(cnt[i], 1);
};
utils::AliasSampler<Idx, DType, false> as(re, likelihood);
......@@ -242,15 +239,15 @@ void _TestBiasedChoice(RandomEngine* re) {
{
Idx sample_num = 100000;
Idx population = 1000000;
Idx split[] = {0, population/2, population};
Idx split[] = {0, population / 2, population};
FloatArray bias = NDArray::FromVector(std::vector<FloatType>({1, 3}));
IdArray rst = re->BiasedChoice<Idx, FloatType>(sample_num, split, bias, true);
auto rst_data = static_cast<Idx *>(rst->data);
IdArray rst =
re->BiasedChoice<Idx, FloatType>(sample_num, split, bias, true);
auto rst_data = static_cast<Idx*>(rst->data);
Idx larger = 0;
for (Idx i = 0 ; i < sample_num ; ++i)
if (rst_data[i] >= population / 2)
larger++;
for (Idx i = 0; i < sample_num; ++i)
if (rst_data[i] >= population / 2) larger++;
ASSERT_LE(fabs((double)larger / sample_num - 0.75), 1e-2);
}
// without replacement
......@@ -260,8 +257,9 @@ void _TestBiasedChoice(RandomEngine* re) {
Idx split[] = {0, sample_num, population};
FloatArray bias = NDArray::FromVector(std::vector<FloatType>({1, 0}));
IdArray rst = re->BiasedChoice<Idx, FloatType>(sample_num, split, bias, false);
auto rst_data = static_cast<Idx *>(rst->data);
IdArray rst =
re->BiasedChoice<Idx, FloatType>(sample_num, split, bias, false);
auto rst_data = static_cast<Idx*>(rst->data);
std::set<Idx> idxset;
for (int64_t i = 0; i < sample_num; ++i) {
......
......@@ -2,10 +2,12 @@
#include <dgl/immutable_graph.h>
#include <dmlc/memory_io.h>
#include <gtest/gtest.h>
#include <algorithm>
#include <iostream>
#include <memory>
#include <vector>
#include "../../src/graph/heterograph.h"
#include "../../src/graph/unit_graph.h"
#include "./common.h"
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment