"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "27062c3631b7011a5df45782b8e3d01349d1f3e9"
Unverified Commit f8d4264e authored by Minjie Wang's avatar Minjie Wang Committed by GitHub
Browse files

[Feature] Neighbor-hood based sampling APIs (#1251)

* WIP: working on random choices

* light slice

* basic CPU impl

* add python binding; fix CreateFromCOO and CreateFromCSR returning unitgraph

* simple test case works

* fix bug in slicing probability array

* fix bug in getting the correct relation graph

* fix bug in creating placeholder graph

* enable omp

* add cpp test

* sample topk

* add in|out_subgraph

* try fix lint; passed all unittests

* fix lint

* fix msvc compile; add sorted flag and constructors

* fix msvc

* coosort

* COOSort; CSRRowWiseSampling; CSRRowWiseTopk

* WIP: remove DType in CSR and COO; Restrict data array to be IdArray

* fix all CSR ops for missing data array

* compiled

* passed tests

* lint

* test sampling out edge

* test different per-relation fanout/k values

* fix bug in random choice

* finished cpptest

* fix compile

* Add induced edges

* add check

* fixed bug in sampling on hypersparse graph; add tests

* add ascending flag

* in|out_subgraph returns subgraph and induced eid

* address comments

* lint

* fix
parent c7c0fd0e
## DGL Sampler
This directory contains the implementations of the graph sampling routines available in DGL 0.5+.
### Code Hierarchy
#### Random walks:
* `randomwalks.h:`
* `randomwalks_cpu.h:GenericRandomWalk(hg, seeds, max_num_steps, step)`
* `metapath_randomwalk.h:RandomWalk(hg, seeds, metapath, prob, terminate)`
/*!
* Copyright (c) 2020 by Contributors
* \file graph/sampling/neighbor.cc
* \brief Definition of neighborhood-based sampler APIs.
*/
#include <dgl/runtime/container.h>
#include <dgl/packed_func_ext.h>
#include <dgl/array.h>
#include <dgl/sampling/neighbor.h>
#include "../../../c_api_common.h"
#include "../../unit_graph.h"
using namespace dgl::runtime;
using namespace dgl::aten;
namespace dgl {
namespace sampling {
/*!
 * \brief Sample neighboring edges of the given nodes, one relation at a time.
 *
 * For every edge type, the seed nodes of the relevant endpoint type (source
 * type for EdgeDir::kOut, destination type otherwise) are handed to the
 * row-wise sampling routine of whichever sparse format SelectFormat reports.
 * The sampled COO matrix becomes that relation of the returned graph.
 *
 * \param hg The heterograph to sample from.
 * \param nodes Seed node IDs; exactly one array per node type.
 * \param fanouts One value per edge type, forwarded to the row-wise sampler.
 * \param dir Whether in-edges or out-edges of the seed nodes are sampled.
 * \param prob One array per edge type, forwarded to the row-wise sampler.
 * \param replace Forwarded to the row-wise sampler.
 * \return A subgraph built on the metagraph of \a hg. induced_edges holds the
 *         data array of each sampled COO matrix (or an empty array when it is
 *         undefined or no seed node was given). induced_vertices is only
 *         resized to the number of node types; its entries are not filled.
 */
HeteroSubgraph SampleNeighbors(
    const HeteroGraphPtr hg,
    const std::vector<IdArray>& nodes,
    const std::vector<int64_t>& fanouts,
    EdgeDir dir,
    const std::vector<FloatArray>& prob,
    bool replace) {
  // sanity check
  CHECK_EQ(nodes.size(), hg->NumVertexTypes())
    << "Number of node ID tensors must match the number of node types.";
  CHECK_EQ(fanouts.size(), hg->NumEdgeTypes())
    << "Number of fanout values must match the number of edge types.";
  CHECK_EQ(prob.size(), hg->NumEdgeTypes())
    << "Number of probability tensors must match the number of edge types.";

  std::vector<HeteroGraphPtr> subrels(hg->NumEdgeTypes());
  std::vector<IdArray> induced_edges(hg->NumEdgeTypes());
  for (dgl_type_t etype = 0; etype < hg->NumEdgeTypes(); ++etype) {
    auto pair = hg->meta_graph()->FindEdge(etype);
    const dgl_type_t src_vtype = pair.first;
    const dgl_type_t dst_vtype = pair.second;
    // Out-edge sampling starts from source nodes, in-edge sampling from destinations.
    const IdArray nodes_ntype = nodes[(dir == EdgeDir::kOut)? src_vtype : dst_vtype];
    const int64_t num_nodes = nodes_ntype->shape[0];
    if (num_nodes == 0) {
      // No node provided in the type, create a placeholder relation graph
      subrels[etype] = UnitGraph::Empty(
        hg->GetRelationGraph(etype)->NumVertexTypes(),
        hg->NumVertices(src_vtype),
        hg->NumVertices(dst_vtype),
        hg->DataType(), hg->Context());
      induced_edges[etype] = IdArray::Empty({0}, hg->DataType(), hg->Context());
    } else {
      // sample from one relation graph
      auto req_fmt = (dir == EdgeDir::kOut)? SparseFormat::CSR : SparseFormat::CSC;
      auto avail_fmt = hg->SelectFormat(etype, req_fmt);
      COOMatrix sampled_coo;
      switch (avail_fmt) {
        case SparseFormat::COO:
          if (dir == EdgeDir::kIn) {
            // Transpose so that destinations become rows, sample, then transpose back.
            sampled_coo = aten::COOTranspose(aten::COORowWiseSampling(
              aten::COOTranspose(hg->GetCOOMatrix(etype)),
              nodes_ntype, fanouts[etype], prob[etype], replace));
          } else {
            sampled_coo = aten::COORowWiseSampling(
              hg->GetCOOMatrix(etype), nodes_ntype, fanouts[etype], prob[etype], replace);
          }
          break;
        case SparseFormat::CSR:
          // This check fails only when in-edges were requested on a CSR matrix.
          CHECK(dir == EdgeDir::kOut) << "Cannot sample in edges on CSR matrix.";
          sampled_coo = aten::CSRRowWiseSampling(
            hg->GetCSRMatrix(etype), nodes_ntype, fanouts[etype], prob[etype], replace);
          break;
        case SparseFormat::CSC:
          // This check fails only when out-edges were requested on a CSC matrix.
          CHECK(dir == EdgeDir::kIn) << "Cannot sample out edges on CSC matrix.";
          sampled_coo = aten::CSRRowWiseSampling(
            hg->GetCSCMatrix(etype), nodes_ntype, fanouts[etype], prob[etype], replace);
          // The CSC rows are destination nodes; transpose back to (src, dst) order.
          sampled_coo = aten::COOTranspose(sampled_coo);
          break;
        default:
          LOG(FATAL) << "Unsupported sparse format.";
      }
      subrels[etype] = UnitGraph::CreateFromCOO(
        hg->GetRelationGraph(etype)->NumVertexTypes(), sampled_coo);
      if (sampled_coo.data.defined()) {
        induced_edges[etype] = sampled_coo.data;
      } else {
        induced_edges[etype] = IdArray::Empty({0}, hg->DataType(), hg->Context());
      }
    }
  }

  HeteroSubgraph ret;
  ret.graph = CreateHeteroGraph(hg->meta_graph(), subrels);
  ret.induced_vertices.resize(hg->NumVertexTypes());
  ret.induced_edges = std::move(induced_edges);
  return ret;
}
/*!
 * \brief Select top-k neighboring edges of the given nodes, one relation at a time.
 *
 * For every edge type, the seed nodes of the relevant endpoint type (source
 * type for EdgeDir::kOut, destination type otherwise) are handed to the
 * row-wise top-k routine of whichever sparse format SelectFormat reports.
 * The selected COO matrix becomes that relation of the returned graph.
 *
 * \param hg The heterograph to select from.
 * \param nodes Seed node IDs; exactly one array per node type.
 * \param k One value per edge type, forwarded to the row-wise top-k routine.
 * \param dir Whether in-edges or out-edges of the seed nodes are considered.
 * \param weight One array per edge type, forwarded to the row-wise top-k routine.
 * \param ascending Forwarded to the row-wise top-k routine.
 * \return A subgraph built on the metagraph of \a hg. induced_edges holds the
 *         data array of each selected COO matrix (or an empty array when it is
 *         undefined or no seed node was given). induced_vertices is only
 *         resized to the number of node types; its entries are not filled.
 */
HeteroSubgraph SampleNeighborsTopk(
    const HeteroGraphPtr hg,
    const std::vector<IdArray>& nodes,
    const std::vector<int64_t>& k,
    EdgeDir dir,
    const std::vector<FloatArray>& weight,
    bool ascending) {
  // sanity check
  CHECK_EQ(nodes.size(), hg->NumVertexTypes())
    << "Number of node ID tensors must match the number of node types.";
  CHECK_EQ(k.size(), hg->NumEdgeTypes())
    << "Number of k values must match the number of edge types.";
  CHECK_EQ(weight.size(), hg->NumEdgeTypes())
    << "Number of weight tensors must match the number of edge types.";

  std::vector<HeteroGraphPtr> subrels(hg->NumEdgeTypes());
  std::vector<IdArray> induced_edges(hg->NumEdgeTypes());
  for (dgl_type_t etype = 0; etype < hg->NumEdgeTypes(); ++etype) {
    auto pair = hg->meta_graph()->FindEdge(etype);
    const dgl_type_t src_vtype = pair.first;
    const dgl_type_t dst_vtype = pair.second;
    // Out-edge selection starts from source nodes, in-edge selection from destinations.
    const IdArray nodes_ntype = nodes[(dir == EdgeDir::kOut)? src_vtype : dst_vtype];
    const int64_t num_nodes = nodes_ntype->shape[0];
    if (num_nodes == 0) {
      // No node provided in the type, create a placeholder relation graph
      subrels[etype] = UnitGraph::Empty(
        hg->GetRelationGraph(etype)->NumVertexTypes(),
        hg->NumVertices(src_vtype),
        hg->NumVertices(dst_vtype),
        hg->DataType(), hg->Context());
      induced_edges[etype] = IdArray::Empty({0}, hg->DataType(), hg->Context());
    } else {
      // sample from one relation graph
      auto req_fmt = (dir == EdgeDir::kOut)? SparseFormat::CSR : SparseFormat::CSC;
      auto avail_fmt = hg->SelectFormat(etype, req_fmt);
      COOMatrix sampled_coo;
      switch (avail_fmt) {
        case SparseFormat::COO:
          if (dir == EdgeDir::kIn) {
            // Transpose so that destinations become rows, select, then transpose back.
            sampled_coo = aten::COOTranspose(aten::COORowWiseTopk(
              aten::COOTranspose(hg->GetCOOMatrix(etype)),
              nodes_ntype, k[etype], weight[etype], ascending));
          } else {
            sampled_coo = aten::COORowWiseTopk(
              hg->GetCOOMatrix(etype), nodes_ntype, k[etype], weight[etype], ascending);
          }
          break;
        case SparseFormat::CSR:
          // This check fails only when in-edges were requested on a CSR matrix.
          CHECK(dir == EdgeDir::kOut) << "Cannot sample in edges on CSR matrix.";
          sampled_coo = aten::CSRRowWiseTopk(
            hg->GetCSRMatrix(etype), nodes_ntype, k[etype], weight[etype], ascending);
          break;
        case SparseFormat::CSC:
          // This check fails only when out-edges were requested on a CSC matrix.
          CHECK(dir == EdgeDir::kIn) << "Cannot sample out edges on CSC matrix.";
          sampled_coo = aten::CSRRowWiseTopk(
            hg->GetCSCMatrix(etype), nodes_ntype, k[etype], weight[etype], ascending);
          // The CSC rows are destination nodes; transpose back to (src, dst) order.
          sampled_coo = aten::COOTranspose(sampled_coo);
          break;
        default:
          LOG(FATAL) << "Unsupported sparse format.";
      }
      subrels[etype] = UnitGraph::CreateFromCOO(
        hg->GetRelationGraph(etype)->NumVertexTypes(), sampled_coo);
      if (sampled_coo.data.defined()) {
        induced_edges[etype] = sampled_coo.data;
      } else {
        induced_edges[etype] = IdArray::Empty({0}, hg->DataType(), hg->Context());
      }
    }
  }

  HeteroSubgraph ret;
  ret.graph = CreateHeteroGraph(hg->meta_graph(), subrels);
  ret.induced_vertices.resize(hg->NumVertexTypes());
  ret.induced_edges = std::move(induced_edges);
  return ret;
}
// C API entry point for sampling::SampleNeighbors.
// Arguments: graph, per-type node ID list, per-relation fanout list,
// direction string ("in" or "out"), per-relation probability list, replace flag.
DGL_REGISTER_GLOBAL("sampling.neighbor._CAPI_DGLSampleNeighbors")
.set_body([] (DGLArgs args, DGLRetValue *rv) {
    // Unpack the positional arguments in order.
    HeteroGraphRef graph_ref = args[0];
    const auto& seed_nodes = ListValueToVector<IdArray>(args[1]);
    const auto& fanout_vec = ListValueToVector<int64_t>(args[2]);
    const std::string direction = args[3];
    const auto& prob_vec = ListValueToVector<FloatArray>(args[4]);
    const bool with_replacement = args[5];

    // Only the two literal direction names are accepted.
    const bool wants_in = (direction == "in");
    CHECK(wants_in || direction == "out")
      << "Invalid edge direction. Must be \"in\" or \"out\".";
    EdgeDir edge_dir = EdgeDir::kOut;
    if (wants_in) {
      edge_dir = EdgeDir::kIn;
    }

    // Run the sampler and hand the result back wrapped in an object reference.
    auto result = std::make_shared<HeteroSubgraph>();
    *result = sampling::SampleNeighbors(
        graph_ref.sptr(), seed_nodes, fanout_vec, edge_dir, prob_vec, with_replacement);
    *rv = HeteroSubgraphRef(result);
  });
// C API entry point for sampling::SampleNeighborsTopk.
// Arguments: graph, per-type node ID list, per-relation k list,
// direction string ("in" or "out"), per-relation weight list, ascending flag.
DGL_REGISTER_GLOBAL("sampling.neighbor._CAPI_DGLSampleNeighborsTopk")
.set_body([] (DGLArgs args, DGLRetValue *rv) {
    HeteroGraphRef hg = args[0];
    const auto& nodes = ListValueToVector<IdArray>(args[1]);
    const auto& k = ListValueToVector<int64_t>(args[2]);
    const std::string dir_str = args[3];
    const auto& weight = ListValueToVector<FloatArray>(args[4]);
    const bool ascending = args[5];

    CHECK(dir_str == "in" || dir_str == "out")
      << "Invalid edge direction. Must be \"in\" or \"out\".";
    EdgeDir dir = (dir_str == "in")? EdgeDir::kIn : EdgeDir::kOut;

    std::shared_ptr<HeteroSubgraph> subg(new HeteroSubgraph);
    *subg = sampling::SampleNeighborsTopk(
        hg.sptr(), nodes, k, dir, weight, ascending);

    // The payload is a HeteroSubgraph, so return it through the subgraph
    // reference type, matching _CAPI_DGLSampleNeighbors.
    *rv = HeteroSubgraphRef(subg);
  });
} // namespace sampling
} // namespace dgl
/*! /*!
* Copyright (c) 2018 by Contributors * Copyright (c) 2018 by Contributors
* \file graph/sampler/randomwalks.cc * \file graph/sampling/randomwalks.cc
* \brief Dispatcher of different DGL random walks by device type * \brief Dispatcher of different DGL random walks by device type
*/ */
...@@ -113,10 +113,7 @@ DGL_REGISTER_GLOBAL("sampling.randomwalks._CAPI_DGLSamplingRandomWalk") ...@@ -113,10 +113,7 @@ DGL_REGISTER_GLOBAL("sampling.randomwalks._CAPI_DGLSamplingRandomWalk")
TypeArray metapath = args[2]; TypeArray metapath = args[2];
List<Value> prob = args[3]; List<Value> prob = args[3];
std::vector<FloatArray> prob_vec; const auto& prob_vec = ListValueToVector<FloatArray>(prob);
prob_vec.reserve(prob.size());
for (Value val : prob)
prob_vec.push_back(val->data);
auto result = sampling::RandomWalk(hg.sptr(), seeds, metapath, prob_vec); auto result = sampling::RandomWalk(hg.sptr(), seeds, metapath, prob_vec);
List<Value> ret; List<Value> ret;
...@@ -133,10 +130,7 @@ DGL_REGISTER_GLOBAL("sampling.randomwalks._CAPI_DGLSamplingRandomWalkWithRestart ...@@ -133,10 +130,7 @@ DGL_REGISTER_GLOBAL("sampling.randomwalks._CAPI_DGLSamplingRandomWalkWithRestart
List<Value> prob = args[3]; List<Value> prob = args[3];
double restart_prob = args[4]; double restart_prob = args[4];
std::vector<FloatArray> prob_vec; const auto& prob_vec = ListValueToVector<FloatArray>(prob);
prob_vec.reserve(prob.size());
for (Value val : prob)
prob_vec.push_back(val->data);
auto result = sampling::RandomWalkWithRestart( auto result = sampling::RandomWalkWithRestart(
hg.sptr(), seeds, metapath, prob_vec, restart_prob); hg.sptr(), seeds, metapath, prob_vec, restart_prob);
...@@ -154,10 +148,7 @@ DGL_REGISTER_GLOBAL("sampling.randomwalks._CAPI_DGLSamplingRandomWalkWithStepwis ...@@ -154,10 +148,7 @@ DGL_REGISTER_GLOBAL("sampling.randomwalks._CAPI_DGLSamplingRandomWalkWithStepwis
List<Value> prob = args[3]; List<Value> prob = args[3];
FloatArray restart_prob = args[4]; FloatArray restart_prob = args[4];
std::vector<FloatArray> prob_vec; const auto& prob_vec = ListValueToVector<FloatArray>(prob);
prob_vec.reserve(prob.size());
for (Value val : prob)
prob_vec.push_back(val->data);
auto result = sampling::RandomWalkWithStepwiseRestart( auto result = sampling::RandomWalkWithStepwiseRestart(
hg.sptr(), seeds, metapath, prob_vec, restart_prob); hg.sptr(), seeds, metapath, prob_vec, restart_prob);
......
...@@ -79,8 +79,12 @@ class UnitGraph::COO : public BaseHeteroGraph { ...@@ -79,8 +79,12 @@ class UnitGraph::COO : public BaseHeteroGraph {
adj_ = aten::COOMatrix{num_src, num_dst, src, dst}; adj_ = aten::COOMatrix{num_src, num_dst, src, dst};
} }
explicit COO(GraphPtr metagraph, const aten::COOMatrix& coo) COO(GraphPtr metagraph, const aten::COOMatrix& coo)
: BaseHeteroGraph(metagraph), adj_(coo) {} : BaseHeteroGraph(metagraph), adj_(coo) {
// Data index should not be inherited. Edges in COO format are always
// assigned ids from 0 to num_edges - 1.
adj_.data = IdArray();
}
inline dgl_type_t SrcType() const { inline dgl_type_t SrcType() const {
return 0; return 0;
...@@ -324,6 +328,25 @@ class UnitGraph::COO : public BaseHeteroGraph { ...@@ -324,6 +328,25 @@ class UnitGraph::COO : public BaseHeteroGraph {
} }
} }
/*! \return The COO adjacency matrix held by this graph; \a etype is not used. */
aten::COOMatrix GetCOOMatrix(dgl_type_t etype) const override {
return adj_;
}
/*! \brief Not available on the COO-only graph: reports a fatal error. */
aten::CSRMatrix GetCSCMatrix(dgl_type_t etype) const override {
LOG(FATAL) << "Not enabled for COO graph";
// Default-constructed matrix supplies the return value required by the signature.
return aten::CSRMatrix();
}
/*! \brief Not available on the COO-only graph: reports a fatal error. */
aten::CSRMatrix GetCSRMatrix(dgl_type_t etype) const override {
LOG(FATAL) << "Not enabled for COO graph";
// Default-constructed matrix supplies the return value required by the signature.
return aten::CSRMatrix();
}
/*! \brief Not available on the COO-only graph: reports a fatal error. */
SparseFormat SelectFormat(dgl_type_t etype, SparseFormat preferred_format) const override {
LOG(FATAL) << "Not enabled for COO graph";
// SparseFormat::ANY supplies the return value required by the signature.
return SparseFormat::ANY;
}
HeteroSubgraph VertexSubgraph(const std::vector<IdArray>& vids) const override { HeteroSubgraph VertexSubgraph(const std::vector<IdArray>& vids) const override {
CHECK_EQ(vids.size(), NumVertexTypes()) << "Number of vertex types mismatch"; CHECK_EQ(vids.size(), NumVertexTypes()) << "Number of vertex types mismatch";
auto srcvids = vids[SrcType()], dstvids = vids[DstType()]; auto srcvids = vids[SrcType()], dstvids = vids[DstType()];
...@@ -419,7 +442,7 @@ class UnitGraph::CSR : public BaseHeteroGraph { ...@@ -419,7 +442,7 @@ class UnitGraph::CSR : public BaseHeteroGraph {
sorted_ = false; sorted_ = false;
} }
explicit CSR(GraphPtr metagraph, const aten::CSRMatrix& csr) CSR(GraphPtr metagraph, const aten::CSRMatrix& csr)
: BaseHeteroGraph(metagraph), adj_(csr) { : BaseHeteroGraph(metagraph), adj_(csr) {
sorted_ = false; sorted_ = false;
} }
...@@ -570,22 +593,22 @@ class UnitGraph::CSR : public BaseHeteroGraph { ...@@ -570,22 +593,22 @@ class UnitGraph::CSR : public BaseHeteroGraph {
} }
std::pair<dgl_id_t, dgl_id_t> FindEdge(dgl_type_t etype, dgl_id_t eid) const override { std::pair<dgl_id_t, dgl_id_t> FindEdge(dgl_type_t etype, dgl_id_t eid) const override {
LOG(INFO) << "Not enabled for CSR graph."; LOG(FATAL) << "Not enabled for CSR graph.";
return {}; return {};
} }
EdgeArray FindEdges(dgl_type_t etype, IdArray eids) const override { EdgeArray FindEdges(dgl_type_t etype, IdArray eids) const override {
LOG(INFO) << "Not enabled for CSR graph."; LOG(FATAL) << "Not enabled for CSR graph.";
return {}; return {};
} }
EdgeArray InEdges(dgl_type_t etype, dgl_id_t vid) const override { EdgeArray InEdges(dgl_type_t etype, dgl_id_t vid) const override {
LOG(INFO) << "Not enabled for CSR graph."; LOG(FATAL) << "Not enabled for CSR graph.";
return {}; return {};
} }
EdgeArray InEdges(dgl_type_t etype, IdArray vids) const override { EdgeArray InEdges(dgl_type_t etype, IdArray vids) const override {
LOG(INFO) << "Not enabled for CSR graph."; LOG(FATAL) << "Not enabled for CSR graph.";
return {}; return {};
} }
...@@ -616,12 +639,12 @@ class UnitGraph::CSR : public BaseHeteroGraph { ...@@ -616,12 +639,12 @@ class UnitGraph::CSR : public BaseHeteroGraph {
} }
uint64_t InDegree(dgl_type_t etype, dgl_id_t vid) const override { uint64_t InDegree(dgl_type_t etype, dgl_id_t vid) const override {
LOG(INFO) << "Not enabled for CSR graph."; LOG(FATAL) << "Not enabled for CSR graph.";
return {}; return {};
} }
DegreeArray InDegrees(dgl_type_t etype, IdArray vids) const override { DegreeArray InDegrees(dgl_type_t etype, IdArray vids) const override {
LOG(INFO) << "Not enabled for CSR graph."; LOG(FATAL) << "Not enabled for CSR graph.";
return {}; return {};
} }
...@@ -656,12 +679,12 @@ class UnitGraph::CSR : public BaseHeteroGraph { ...@@ -656,12 +679,12 @@ class UnitGraph::CSR : public BaseHeteroGraph {
} }
DGLIdIters PredVec(dgl_type_t etype, dgl_id_t vid) const override { DGLIdIters PredVec(dgl_type_t etype, dgl_id_t vid) const override {
LOG(INFO) << "Not enabled for CSR graph."; LOG(FATAL) << "Not enabled for CSR graph.";
return {}; return {};
} }
DGLIdIters InEdgeVec(dgl_type_t etype, dgl_id_t vid) const override { DGLIdIters InEdgeVec(dgl_type_t etype, dgl_id_t vid) const override {
LOG(INFO) << "Not enabled for CSR graph."; LOG(FATAL) << "Not enabled for CSR graph.";
return {}; return {};
} }
...@@ -671,6 +694,25 @@ class UnitGraph::CSR : public BaseHeteroGraph { ...@@ -671,6 +694,25 @@ class UnitGraph::CSR : public BaseHeteroGraph {
return {adj_.indptr, adj_.indices, adj_.data}; return {adj_.indptr, adj_.indices, adj_.data};
} }
/*! \brief Not available on the CSR-only graph: reports a fatal error. */
aten::COOMatrix GetCOOMatrix(dgl_type_t etype) const override {
LOG(FATAL) << "Not enabled for CSR graph";
// Default-constructed matrix supplies the return value required by the signature.
return aten::COOMatrix();
}
/*! \brief Not available on the CSR-only graph: reports a fatal error. */
aten::CSRMatrix GetCSCMatrix(dgl_type_t etype) const override {
LOG(FATAL) << "Not enabled for CSR graph";
// Default-constructed matrix supplies the return value required by the signature.
return aten::CSRMatrix();
}
/*! \return The CSR adjacency matrix held by this graph; \a etype is not used. */
aten::CSRMatrix GetCSRMatrix(dgl_type_t etype) const override {
return adj_;
}
/*! \brief Not available on the CSR-only graph: reports a fatal error. */
SparseFormat SelectFormat(dgl_type_t etype, SparseFormat preferred_format) const override {
LOG(FATAL) << "Not enabled for CSR graph";
// SparseFormat::ANY supplies the return value required by the signature.
return SparseFormat::ANY;
}
HeteroSubgraph VertexSubgraph(const std::vector<IdArray>& vids) const override { HeteroSubgraph VertexSubgraph(const std::vector<IdArray>& vids) const override {
CHECK_EQ(vids.size(), NumVertexTypes()) << "Number of vertex types mismatch"; CHECK_EQ(vids.size(), NumVertexTypes()) << "Number of vertex types mismatch";
auto srcvids = vids[SrcType()], dstvids = vids[DstType()]; auto srcvids = vids[SrcType()], dstvids = vids[DstType()];
...@@ -688,7 +730,7 @@ class UnitGraph::CSR : public BaseHeteroGraph { ...@@ -688,7 +730,7 @@ class UnitGraph::CSR : public BaseHeteroGraph {
HeteroSubgraph EdgeSubgraph( HeteroSubgraph EdgeSubgraph(
const std::vector<IdArray>& eids, bool preserve_nodes = false) const override { const std::vector<IdArray>& eids, bool preserve_nodes = false) const override {
LOG(INFO) << "Not enabled for CSR graph."; LOG(FATAL) << "Not enabled for CSR graph.";
return {}; return {};
} }
...@@ -982,7 +1024,8 @@ HeteroSubgraph UnitGraph::EdgeSubgraph( ...@@ -982,7 +1024,8 @@ HeteroSubgraph UnitGraph::EdgeSubgraph(
} }
HeteroGraphPtr UnitGraph::CreateFromCOO( HeteroGraphPtr UnitGraph::CreateFromCOO(
int64_t num_vtypes, int64_t num_src, int64_t num_dst, IdArray row, IdArray col, int64_t num_vtypes, int64_t num_src, int64_t num_dst,
IdArray row, IdArray col,
SparseFormat restrict_format) { SparseFormat restrict_format) {
CHECK(num_vtypes == 1 || num_vtypes == 2); CHECK(num_vtypes == 1 || num_vtypes == 2);
if (num_vtypes == 1) if (num_vtypes == 1)
...@@ -994,6 +1037,18 @@ HeteroGraphPtr UnitGraph::CreateFromCOO( ...@@ -994,6 +1037,18 @@ HeteroGraphPtr UnitGraph::CreateFromCOO(
new UnitGraph(mg, nullptr, nullptr, coo, restrict_format)); new UnitGraph(mg, nullptr, nullptr, coo, restrict_format));
} }
// Builds a unit graph directly from an existing COO matrix.
// num_vtypes must be 1 or 2; with a single vertex type the matrix must be square.
HeteroGraphPtr UnitGraph::CreateFromCOO(
    int64_t num_vtypes, const aten::COOMatrix& mat,
    SparseFormat restrict_format) {
  CHECK(num_vtypes == 1 || num_vtypes == 2);
  if (num_vtypes == 1) {
    CHECK_EQ(mat.num_rows, mat.num_cols);
  }
  const auto metagraph = CreateUnitGraphMetaGraph(num_vtypes);
  // Only the COO slot is populated; the two preceding constructor slots stay null.
  COOPtr coo_graph = std::make_shared<COO>(metagraph, mat);
  UnitGraph* raw_graph = new UnitGraph(
      metagraph, nullptr, nullptr, coo_graph, restrict_format);
  return HeteroGraphPtr(raw_graph);
}
HeteroGraphPtr UnitGraph::CreateFromCSR( HeteroGraphPtr UnitGraph::CreateFromCSR(
int64_t num_vtypes, int64_t num_src, int64_t num_dst, int64_t num_vtypes, int64_t num_src, int64_t num_dst,
IdArray indptr, IdArray indices, IdArray edge_ids, SparseFormat restrict_format) { IdArray indptr, IdArray indices, IdArray edge_ids, SparseFormat restrict_format) {
...@@ -1005,6 +1060,17 @@ HeteroGraphPtr UnitGraph::CreateFromCSR( ...@@ -1005,6 +1060,17 @@ HeteroGraphPtr UnitGraph::CreateFromCSR(
return HeteroGraphPtr(new UnitGraph(mg, nullptr, csr, nullptr, restrict_format)); return HeteroGraphPtr(new UnitGraph(mg, nullptr, csr, nullptr, restrict_format));
} }
// Builds a unit graph directly from an existing CSR matrix.
// num_vtypes must be 1 or 2; with a single vertex type the matrix must be square.
HeteroGraphPtr UnitGraph::CreateFromCSR(
    int64_t num_vtypes, const aten::CSRMatrix& mat,
    SparseFormat restrict_format) {
  CHECK(num_vtypes == 1 || num_vtypes == 2);
  if (num_vtypes == 1) {
    CHECK_EQ(mat.num_rows, mat.num_cols);
  }
  const auto metagraph = CreateUnitGraphMetaGraph(num_vtypes);
  // The CSR goes into the third constructor slot; the neighbouring slots stay null.
  CSRPtr csr_graph(new CSR(metagraph, mat));
  UnitGraph* raw_graph = new UnitGraph(
      metagraph, nullptr, csr_graph, nullptr, restrict_format);
  return HeteroGraphPtr(raw_graph);
}
HeteroGraphPtr UnitGraph::AsNumBits(HeteroGraphPtr g, uint8_t bits) { HeteroGraphPtr UnitGraph::AsNumBits(HeteroGraphPtr g, uint8_t bits) {
if (g->NumBits() == bits) { if (g->NumBits() == bits) {
return g; return g;
...@@ -1091,10 +1157,8 @@ UnitGraph::CSRPtr UnitGraph::GetOutCSR() const { ...@@ -1091,10 +1157,8 @@ UnitGraph::CSRPtr UnitGraph::GetOutCSR() const {
UnitGraph::COOPtr UnitGraph::GetCOO() const { UnitGraph::COOPtr UnitGraph::GetCOO() const {
if (!coo_) { if (!coo_) {
if (in_csr_) { if (in_csr_) {
const auto& newadj = aten::CSRToCOO(in_csr_->adj(), true); const auto& newadj = aten::COOTranspose(aten::CSRToCOO(in_csr_->adj(), true));
const_cast<UnitGraph*>(this)->coo_ = std::make_shared<COO>( const_cast<UnitGraph*>(this)->coo_ = std::make_shared<COO>(meta_graph(), newadj);
meta_graph(),
aten::COOMatrix{newadj.num_cols, newadj.num_rows, newadj.col, newadj.row});
} else { } else {
CHECK(out_csr_) << "Both CSR are missing."; CHECK(out_csr_) << "Both CSR are missing.";
const auto& newadj = aten::CSRToCOO(out_csr_->adj(), true); const auto& newadj = aten::CSRToCOO(out_csr_->adj(), true);
...@@ -1104,15 +1168,15 @@ UnitGraph::COOPtr UnitGraph::GetCOO() const { ...@@ -1104,15 +1168,15 @@ UnitGraph::COOPtr UnitGraph::GetCOO() const {
return coo_; return coo_;
} }
aten::CSRMatrix UnitGraph::GetInCSRMatrix() const { aten::CSRMatrix UnitGraph::GetCSCMatrix(dgl_type_t etype) const {
return GetInCSR()->adj(); return GetInCSR()->adj();
} }
aten::CSRMatrix UnitGraph::GetOutCSRMatrix() const { aten::CSRMatrix UnitGraph::GetCSRMatrix(dgl_type_t etype) const {
return GetOutCSR()->adj(); return GetOutCSR()->adj();
} }
aten::COOMatrix UnitGraph::GetCOOMatrix() const { aten::COOMatrix UnitGraph::GetCOOMatrix(dgl_type_t etype) const {
return GetCOO()->adj(); return GetCOO()->adj();
} }
...@@ -1186,7 +1250,7 @@ bool UnitGraph::Load(dmlc::Stream* fs) { ...@@ -1186,7 +1250,7 @@ bool UnitGraph::Load(dmlc::Stream* fs) {
// Using Out CSR // Using Out CSR
void UnitGraph::Save(dmlc::Stream* fs) const { void UnitGraph::Save(dmlc::Stream* fs) const {
// Following CreateFromCSR signature // Following CreateFromCSR signature
aten::CSRMatrix csr_matrix = GetOutCSRMatrix(); aten::CSRMatrix csr_matrix = GetCSRMatrix(0);
uint64_t num_vtypes = NumVertexTypes(); uint64_t num_vtypes = NumVertexTypes();
uint64_t num_src = NumVertices(SrcType()); uint64_t num_src = NumVertices(SrcType());
uint64_t num_dst = NumVertices(DstType()); uint64_t num_dst = NumVertices(DstType());
......
...@@ -15,11 +15,15 @@ ...@@ -15,11 +15,15 @@
#include <utility> #include <utility>
#include <string> #include <string>
#include <vector> #include <vector>
#include <memory>
#include "../c_api_common.h" #include "../c_api_common.h"
namespace dgl { namespace dgl {
class UnitGraph;
typedef std::shared_ptr<UnitGraph> UnitGraphPtr;
/*! /*!
* \brief UnitGraph graph * \brief UnitGraph graph
* *
...@@ -144,17 +148,34 @@ class UnitGraph : public BaseHeteroGraph { ...@@ -144,17 +148,34 @@ class UnitGraph : public BaseHeteroGraph {
const std::vector<IdArray>& eids, bool preserve_nodes = false) const override; const std::vector<IdArray>& eids, bool preserve_nodes = false) const override;
// creators // creators
/*!
 * \brief Create a graph with no edges
 *
 * The graph is built from two zero-length ID arrays of the given data type
 * and device context via the array-based CreateFromCOO.
 */
static HeteroGraphPtr Empty(
    int64_t num_vtypes, int64_t num_src, int64_t num_dst,
    DLDataType dtype, DLContext ctx) {
  const IdArray no_rows = IdArray::Empty({0}, dtype, ctx);
  const IdArray no_cols = IdArray::Empty({0}, dtype, ctx);
  return CreateFromCOO(num_vtypes, num_src, num_dst, no_rows, no_cols);
}
/*! \brief Create a graph from COO arrays */ /*! \brief Create a graph from COO arrays */
static HeteroGraphPtr CreateFromCOO( static HeteroGraphPtr CreateFromCOO(
int64_t num_vtypes, int64_t num_src, int64_t num_dst, int64_t num_vtypes, int64_t num_src, int64_t num_dst,
IdArray row, IdArray col, SparseFormat restrict_format = SparseFormat::ANY); IdArray row, IdArray col, SparseFormat restrict_format = SparseFormat::ANY);
/*! \brief Create a graph from a COO matrix */
static HeteroGraphPtr CreateFromCOO(
int64_t num_vtypes, const aten::COOMatrix& mat,
SparseFormat restrict_format = SparseFormat::ANY);
/*! \brief Create a graph from (out) CSR arrays */ /*! \brief Create a graph from (out) CSR arrays */
static HeteroGraphPtr CreateFromCSR( static HeteroGraphPtr CreateFromCSR(
int64_t num_vtypes, int64_t num_src, int64_t num_dst, int64_t num_vtypes, int64_t num_src, int64_t num_dst,
IdArray indptr, IdArray indices, IdArray edge_ids, IdArray indptr, IdArray indices, IdArray edge_ids,
SparseFormat restrict_format = SparseFormat::ANY); SparseFormat restrict_format = SparseFormat::ANY);
/*! \brief Create a graph from a CSR matrix */
static HeteroGraphPtr CreateFromCSR(
int64_t num_vtypes, const aten::CSRMatrix& mat,
SparseFormat restrict_format = SparseFormat::ANY);
/*! \brief Convert the graph to use the given number of bits for storage */ /*! \brief Convert the graph to use the given number of bits for storage */
static HeteroGraphPtr AsNumBits(HeteroGraphPtr g, uint8_t bits); static HeteroGraphPtr AsNumBits(HeteroGraphPtr g, uint8_t bits);
...@@ -170,14 +191,18 @@ class UnitGraph : public BaseHeteroGraph { ...@@ -170,14 +191,18 @@ class UnitGraph : public BaseHeteroGraph {
/*! \return Return the COO format. Create from other format if not exist. */ /*! \return Return the COO format. Create from other format if not exist. */
COOPtr GetCOO() const; COOPtr GetCOO() const;
/*! \return Return the in-edge CSR in the matrix form */ /*! \return Return the COO matrix form */
aten::CSRMatrix GetInCSRMatrix() const; aten::COOMatrix GetCOOMatrix(dgl_type_t etype) const override;
/*! \return Return the in-edge CSC in the matrix form */
aten::CSRMatrix GetCSCMatrix(dgl_type_t etype) const override;
/*! \return Return the out-edge CSR in the matrix form */ /*! \return Return the out-edge CSR in the matrix form */
aten::CSRMatrix GetOutCSRMatrix() const; aten::CSRMatrix GetCSRMatrix(dgl_type_t etype) const override;
/*! \return Return the COO matrix form */ SparseFormat SelectFormat(dgl_type_t etype, SparseFormat preferred_format) const override {
aten::COOMatrix GetCOOMatrix() const; return SelectFormat(preferred_format);
}
/*! \return Load UnitGraph from stream, using CSRMatrix*/ /*! \return Load UnitGraph from stream, using CSRMatrix*/
bool Load(dmlc::Stream* fs); bool Load(dmlc::Stream* fs);
......
...@@ -249,11 +249,11 @@ class UnitGraphCSRWrapper : public CSRWrapper { ...@@ -249,11 +249,11 @@ class UnitGraphCSRWrapper : public CSRWrapper {
gptr_(graph) { } gptr_(graph) { }
aten::CSRMatrix GetInCSRMatrix() const override { aten::CSRMatrix GetInCSRMatrix() const override {
return gptr_->GetInCSRMatrix(); return gptr_->GetCSCMatrix(0);
} }
aten::CSRMatrix GetOutCSRMatrix() const override { aten::CSRMatrix GetOutCSRMatrix() const override {
return gptr_->GetOutCSRMatrix(); return gptr_->GetCSRMatrix(0);
} }
DGLContext Context() const override { DGLContext Context() const override {
......
...@@ -5,13 +5,7 @@ ...@@ -5,13 +5,7 @@
*/ */
#include <dgl/random.h> #include <dgl/random.h>
#include <algorithm> #include <dgl/array.h>
#include <utility>
#include <queue>
#include <cstdlib>
#include <cmath>
#include <numeric>
#include <limits>
#include <vector> #include <vector>
#include "sample_utils.h" #include "sample_utils.h"
...@@ -19,8 +13,9 @@ namespace dgl { ...@@ -19,8 +13,9 @@ namespace dgl {
template<typename IdxType> template<typename IdxType>
IdxType RandomEngine::Choice(FloatArray prob) { IdxType RandomEngine::Choice(FloatArray prob) {
IdxType ret; IdxType ret = 0;
ATEN_FLOAT_TYPE_SWITCH(prob->dtype, ValueType, "probability", { ATEN_FLOAT_TYPE_SWITCH(prob->dtype, ValueType, "probability", {
// TODO(minjie): allow choosing different sampling algorithms
utils::TreeSampler<IdxType, ValueType, true> sampler(this, prob); utils::TreeSampler<IdxType, ValueType, true> sampler(this, prob);
ret = sampler.Draw(); ret = sampler.Draw();
}); });
...@@ -30,4 +25,65 @@ IdxType RandomEngine::Choice(FloatArray prob) { ...@@ -30,4 +25,65 @@ IdxType RandomEngine::Choice(FloatArray prob) {
template int32_t RandomEngine::Choice<int32_t>(FloatArray); template int32_t RandomEngine::Choice<int32_t>(FloatArray);
template int64_t RandomEngine::Choice<int64_t>(FloatArray); template int64_t RandomEngine::Choice<int64_t>(FloatArray);
/*!
 * \brief Draw \a num indices using the weights in \a prob.
 *
 * \tparam IdxType Integer type of the returned indices.
 * \tparam FloatType Element type used to read \a prob.
 * \param num Number of indices to draw.
 * \param prob Weight array; its first dimension is the population size N.
 * \param replace Whether to draw with replacement. Without replacement,
 *        \a num must not exceed N.
 * \return A CPU array of \a num indices. When num == N without replacement,
 *         the result is simply the range [0, N).
 */
template<typename IdxType, typename FloatType>
IdArray RandomEngine::Choice(int64_t num, FloatArray prob, bool replace) {
  const int64_t N = prob->shape[0];
  if (!replace) {
    CHECK_LE(num, N) << "Cannot take more sample than population when 'replace=false'";
  }
  // Taking the whole population without replacement needs no random draws.
  if (num == N && !replace) {
    return aten::Range(0, N, sizeof(IdxType) * 8, DLContext{kDLCPU, 0});
  }

  const DLDataType dtype{kDLInt, sizeof(IdxType) * 8, 1};
  IdArray ret = IdArray::Empty({num}, dtype, DLContext{kDLCPU, 0});
  IdxType* ret_data = static_cast<IdxType*>(ret->data);
  // The sampler lives on the stack in each branch, so it is released even if
  // drawing throws; no manual new/delete is required.
  if (replace) {
    utils::TreeSampler<IdxType, FloatType, true> sampler(this, prob);
    for (int64_t i = 0; i < num; ++i)
      ret_data[i] = sampler.Draw();
  } else {
    utils::TreeSampler<IdxType, FloatType, false> sampler(this, prob);
    for (int64_t i = 0; i < num; ++i)
      ret_data[i] = sampler.Draw();
  }
  return ret;
}
// Explicit instantiations of the multi-sample Choice for every combination of
// {int32_t, int64_t} index type and {float, double} probability type.
template IdArray RandomEngine::Choice<int32_t, float>(
int64_t num, FloatArray prob, bool replace);
template IdArray RandomEngine::Choice<int64_t, float>(
int64_t num, FloatArray prob, bool replace);
template IdArray RandomEngine::Choice<int32_t, double>(
int64_t num, FloatArray prob, bool replace);
template IdArray RandomEngine::Choice<int64_t, double>(
int64_t num, FloatArray prob, bool replace);
/*!
 * \brief Draw \a num integers uniformly from [0, population).
 *
 * \tparam IdxType Integer type of the returned values.
 * \param num Number of values to draw.
 * \param population Size of the population to draw from.
 * \param replace Whether to draw with replacement. Without replacement,
 *        \a num must not exceed \a population.
 * \return A CPU array holding the \a num drawn values.
 */
template <typename IdxType>
IdArray RandomEngine::UniformChoice(int64_t num, int64_t population, bool replace) {
  if (!replace) {
    CHECK_LE(num, population) << "Cannot take more sample than population when 'replace=false'";
  }
  const DLDataType dtype{kDLInt, sizeof(IdxType) * 8, 1};
  IdArray ret = IdArray::Empty({num}, dtype, DLContext{kDLCPU, 0});
  IdxType* ret_data = static_cast<IdxType*>(ret->data);
  if (replace) {
    for (int64_t i = 0; i < num; ++i)
      ret_data[i] = RandInt(population);
  } else {
    // Reservoir sampling.
    // time: O(population), space: O(num)
    for (int64_t i = 0; i < num; ++i)
      ret_data[i] = i;
    for (int64_t i = num; i < population; ++i) {
      // Item i must replace a reservoir slot with probability num / (i + 1),
      // so the slot is drawn from [0, i]. RandInt excludes its upper bound
      // (it is used above to produce values below `population`), hence i + 1.
      const int64_t j = RandInt(i + 1);
      if (j < num)
        ret_data[j] = i;
    }
  }
  return ret;
}
// Explicit instantiations of UniformChoice for 32-bit and 64-bit index types.
template IdArray RandomEngine::UniformChoice<int32_t>(
int64_t num, int64_t population, bool replace);
template IdArray RandomEngine::UniformChoice<int64_t>(
int64_t num, int64_t population, bool replace);
}; // namespace dgl }; // namespace dgl
...@@ -20,12 +20,12 @@ ...@@ -20,12 +20,12 @@
namespace dgl { namespace dgl {
namespace utils { namespace utils {
template < /*! \brief Base sampler class */
typename Idx, template <typename Idx>
typename DType,
bool replace>
class BaseSampler { class BaseSampler {
public: public:
virtual ~BaseSampler() = default;
/*! \brief Draw one integer sample */
virtual Idx Draw() { virtual Idx Draw() {
LOG(INFO) << "Not implemented yet."; LOG(INFO) << "Not implemented yet.";
return 0; return 0;
...@@ -43,7 +43,7 @@ template < ...@@ -43,7 +43,7 @@ template <
typename Idx, typename Idx,
typename DType, typename DType,
bool replace> bool replace>
class AliasSampler: public BaseSampler<Idx, DType, replace> { class AliasSampler: public BaseSampler<Idx> {
private: private:
RandomEngine *re; RandomEngine *re;
Idx N; Idx N;
...@@ -165,7 +165,7 @@ template < ...@@ -165,7 +165,7 @@ template <
typename Idx, typename Idx,
typename DType, typename DType,
bool replace> bool replace>
class CDFSampler: public BaseSampler<Idx, DType, replace> { class CDFSampler: public BaseSampler<Idx> {
private: private:
RandomEngine *re; RandomEngine *re;
Idx N; Idx N;
...@@ -252,7 +252,7 @@ template < ...@@ -252,7 +252,7 @@ template <
typename Idx, typename Idx,
typename DType, typename DType,
bool replace> bool replace>
class TreeSampler: public BaseSampler<Idx, DType, replace> { class TreeSampler: public BaseSampler<Idx> {
private: private:
RandomEngine *re; RandomEngine *re;
std::vector<DType> weight; // accumulated likelihood of subtrees. std::vector<DType> weight; // accumulated likelihood of subtrees.
......
...@@ -103,7 +103,321 @@ def test_pack_traces(): ...@@ -103,7 +103,321 @@ def test_pack_traces():
assert F.array_equal(result[2], F.tensor([2, 7], dtype=F.int64)) assert F.array_equal(result[2], F.tensor([2, 7], dtype=F.int64))
assert F.array_equal(result[3], F.tensor([0, 2], dtype=F.int64)) assert F.array_equal(result[3], F.tensor([0, 2], dtype=F.int64))
def _gen_neighbor_sampling_test_graph(hypersparse, reverse):
    """Build the graphs shared by the neighbor sampling tests.

    Parameters
    ----------
    hypersparse : bool
        If True, every graph is created with a cardinality of ``1 << 50``
        (per the original note: this should crash if a CSR is allocated).
        If False, no explicit cardinality is passed.
    reverse : bool
        If True, every edge is flipped and the source/destination node types
        of the bipartite relations are swapped.

    Returns
    -------
    (g, hg)
        ``g`` is the 'user'-'follow' graph; ``hg`` is the heterograph combining
        the 'follow', 'play', 'liked-by' and 'flips' relations.  All relations
        except 'flips' carry a float32 edge feature named 'prob'.
    """
    huge = 1 << 50
    card = huge if hypersparse else None
    card2 = (huge, huge) if hypersparse else None

    def _orient(pairs):
        # Flip (src, dst) -> (dst, src) for the reversed variant; order is kept.
        if reverse:
            return [(dst, src) for src, dst in pairs]
        return pairs

    def _bipartite(pairs, utype, etype, vtype):
        # The reversed variant also swaps the two endpoint node types.
        if reverse:
            utype, vtype = vtype, utype
        return dgl.bipartite(_orient(pairs), utype, etype, vtype, card=card2)

    g = dgl.graph(_orient([(1,0),(2,0),(3,0),(0,1),(2,1),(3,1),(0,2)]),
                  'user', 'follow', card=card)
    g.edata['prob'] = F.tensor([.5, .5, 0., .5, .5, 0., 1.], dtype=F.float32)

    g1 = _bipartite([(0,0),(0,1),(1,2),(3,2)], 'user', 'play', 'game')
    g1.edata['prob'] = F.tensor([.8, .5, .5, .5], dtype=F.float32)

    g2 = _bipartite([(2,0),(2,1),(2,2),(1,0),(1,3),(0,0)], 'game', 'liked-by', 'user')
    g2.edata['prob'] = F.tensor([.3, .5, .2, .5, .1, .1], dtype=F.float32)

    # 'flips' intentionally gets no 'prob' feature, as in the original.
    g3 = _bipartite([(0,0),(1,0),(2,0),(3,0)], 'user', 'flips', 'coin')

    hg = dgl.hetero_from_relations([g, g1, g2, g3])
    return g, hg
def _gen_neighbor_topk_test_graph(hypersparse, reverse):
    """Build the graphs shared by the top-k neighbor selection tests.

    Parameters
    ----------
    hypersparse : bool
        If True, every graph is created with a cardinality of ``1 << 50``
        (per the original note: this should crash if a CSR is allocated).
        If False, no explicit cardinality is passed.
    reverse : bool
        If True, every edge is flipped and the source/destination node types
        of the bipartite relations are swapped.

    Returns
    -------
    (g, hg)
        ``g`` is the 'user'-'follow' graph; ``hg`` is the heterograph combining
        the 'follow', 'play', 'liked-by' and 'flips' relations.  Every relation
        carries a float32 edge feature named 'weight'.
    """
    if hypersparse:
        # should crash if allocated a CSR
        card = 1 << 50
        card2 = (1 << 50, 1 << 50)
    else:
        card = None
        card2 = None

    # NOTE: card/card2 were previously computed but never handed to the graph
    # constructors, so the "hypersparse" variant silently built ordinary
    # graphs.  They are now passed through, mirroring
    # _gen_neighbor_sampling_test_graph.
    if reverse:
        g = dgl.graph([(0,1),(0,2),(0,3),(1,0),(1,2),(1,3),(2,0)],
                      'user', 'follow', card=card)
        g.edata['weight'] = F.tensor([.5, .3, 0., -5., 22., 0., 1.], dtype=F.float32)
        g1 = dgl.bipartite([(0,0),(1,0),(2,1),(2,3)], 'game', 'play', 'user', card=card2)
        g1.edata['weight'] = F.tensor([.8, .5, .4, .5], dtype=F.float32)
        g2 = dgl.bipartite([(0,2),(1,2),(2,2),(0,1),(3,1),(0,0)], 'user', 'liked-by', 'game', card=card2)
        g2.edata['weight'] = F.tensor([.3, .5, .2, .5, .1, .1], dtype=F.float32)
        g3 = dgl.bipartite([(0,0),(0,1),(0,2),(0,3)], 'coin', 'flips', 'user', card=card2)
        g3.edata['weight'] = F.tensor([10, 2, 13, -1], dtype=F.float32)
        hg = dgl.hetero_from_relations([g, g1, g2, g3])
    else:
        g = dgl.graph([(1,0),(2,0),(3,0),(0,1),(2,1),(3,1),(0,2)],
                      'user', 'follow', card=card)
        g.edata['weight'] = F.tensor([.5, .3, 0., -5., 22., 0., 1.], dtype=F.float32)
        g1 = dgl.bipartite([(0,0),(0,1),(1,2),(3,2)], 'user', 'play', 'game', card=card2)
        g1.edata['weight'] = F.tensor([.8, .5, .4, .5], dtype=F.float32)
        g2 = dgl.bipartite([(2,0),(2,1),(2,2),(1,0),(1,3),(0,0)], 'game', 'liked-by', 'user', card=card2)
        g2.edata['weight'] = F.tensor([.3, .5, .2, .5, .1, .1], dtype=F.float32)
        g3 = dgl.bipartite([(0,0),(1,0),(2,0),(3,0)], 'user', 'flips', 'coin', card=card2)
        g3.edata['weight'] = F.tensor([10, 2, 13, -1], dtype=F.float32)
        hg = dgl.hetero_from_relations([g, g1, g2, g3])
    return g, hg
def _test_sample_neighbors(hypersparse):
    """Check dgl.sampling.sample_neighbors with its default edge direction.

    Runs against the forward (non-reversed) test graphs.  Each sub-test is
    repeated 10 times and is exercised with and without a probability feature
    and with and without replacement.

    Parameters
    ----------
    hypersparse : bool
        Forwarded to _gen_neighbor_sampling_test_graph.
    """
    g, hg = _gen_neighbor_sampling_test_graph(hypersparse, False)

    def _test1(p, replace):
        # Seeds 0 and 1 with fanout 2: four edges are expected in total.
        for _ in range(10):
            subg = dgl.sampling.sample_neighbors(g, [0, 1], 2, prob=p, replace=replace)
            assert subg.number_of_nodes() == g.number_of_nodes()
            assert subg.number_of_edges() == 4
            u, v = subg.edges()
            assert set(F.asnumpy(F.unique(v))) == {0, 1}
            assert F.array_equal(g.has_edges_between(u, v), F.ones((4,), dtype=F.int64))
            assert F.array_equal(g.edge_ids(u, v), subg.edata[dgl.EID])
            edge_set = set(zip(list(F.asnumpy(u)), list(F.asnumpy(v))))
            if not replace:
                # check no duplication
                assert len(edge_set) == 4
            if p is not None:
                # These edges have probability 0 and must never be drawn.
                assert (3, 0) not in edge_set
                assert (3, 1) not in edge_set
    _test1(None, True)     # w/ replacement, uniform
    _test1(None, False)    # w/o replacement, uniform
    _test1('prob', True)   # w/ replacement
    _test1('prob', False)  # w/o replacement

    def _test2(p, replace):  # fanout > #neighbors
        for _ in range(10):
            subg = dgl.sampling.sample_neighbors(g, [0, 2], 2, prob=p, replace=replace)
            assert subg.number_of_nodes() == g.number_of_nodes()
            num_edges = 4 if replace else 3
            assert subg.number_of_edges() == num_edges
            u, v = subg.edges()
            assert set(F.asnumpy(F.unique(v))) == {0, 2}
            assert F.array_equal(g.has_edges_between(u, v), F.ones((num_edges,), dtype=F.int64))
            assert F.array_equal(g.edge_ids(u, v), subg.edata[dgl.EID])
            edge_set = set(zip(list(F.asnumpy(u)), list(F.asnumpy(v))))
            if not replace:
                # check no duplication
                assert len(edge_set) == num_edges
            if p is not None:
                assert (3, 0) not in edge_set
    _test2(None, True)     # w/ replacement, uniform
    _test2(None, False)    # w/o replacement, uniform
    _test2('prob', True)   # w/ replacement
    _test2('prob', False)  # w/o replacement

    def _test3(p, replace):
        for _ in range(10):
            subg = dgl.sampling.sample_neighbors(hg, {'user' : [0,1], 'game' : 0}, 2, prob=p, replace=replace)
            assert len(subg.ntypes) == 3
            assert len(subg.etypes) == 4
            assert subg['follow'].number_of_edges() == 4
            # The conditional must be parenthesized: `assert a == 2 if replace else 1`
            # parses as `assert ((a == 2) if replace else 1)`, which can never
            # fail when replace is False.
            assert subg['play'].number_of_edges() == (2 if replace else 1)
            assert subg['liked-by'].number_of_edges() == (4 if replace else 3)
            assert subg['flips'].number_of_edges() == 0
    _test3(None, True)     # w/ replacement, uniform
    _test3(None, False)    # w/o replacement, uniform
    _test3('prob', True)   # w/ replacement
    _test3('prob', False)  # w/o replacement

    # test different fanouts for different relations
    for _ in range(10):
        subg = dgl.sampling.sample_neighbors(hg, {'user' : [0,1], 'game' : 0}, [1, 2, 0, 2])
        assert len(subg.ntypes) == 3
        assert len(subg.etypes) == 4
        assert subg['follow'].number_of_edges() == 2
        assert subg['play'].number_of_edges() == 2
        assert subg['liked-by'].number_of_edges() == 0
        assert subg['flips'].number_of_edges() == 0
def _test_sample_neighbors_outedge(hypersparse):
    """Check dgl.sampling.sample_neighbors with ``edge_dir='out'``.

    Runs against the reversed test graphs.  Each sub-test is repeated 10 times
    and is exercised with and without a probability feature and with and
    without replacement.

    Parameters
    ----------
    hypersparse : bool
        Forwarded to _gen_neighbor_sampling_test_graph.
    """
    g, hg = _gen_neighbor_sampling_test_graph(hypersparse, True)

    def _test1(p, replace):
        # Seeds 0 and 1 with fanout 2: four edges are expected in total.
        for _ in range(10):
            subg = dgl.sampling.sample_neighbors(g, [0, 1], 2, prob=p, replace=replace, edge_dir='out')
            assert subg.number_of_nodes() == g.number_of_nodes()
            assert subg.number_of_edges() == 4
            u, v = subg.edges()
            assert set(F.asnumpy(F.unique(u))) == {0, 1}
            assert F.array_equal(g.has_edges_between(u, v), F.ones((4,), dtype=F.int64))
            assert F.array_equal(g.edge_ids(u, v), subg.edata[dgl.EID])
            edge_set = set(zip(list(F.asnumpy(u)), list(F.asnumpy(v))))
            if not replace:
                # check no duplication
                assert len(edge_set) == 4
            if p is not None:
                # These edges have probability 0 and must never be drawn.
                assert (0, 3) not in edge_set
                assert (1, 3) not in edge_set
    _test1(None, True)     # w/ replacement, uniform
    _test1(None, False)    # w/o replacement, uniform
    _test1('prob', True)   # w/ replacement
    _test1('prob', False)  # w/o replacement

    def _test2(p, replace):  # fanout > #neighbors
        for _ in range(10):
            subg = dgl.sampling.sample_neighbors(g, [0, 2], 2, prob=p, replace=replace, edge_dir='out')
            assert subg.number_of_nodes() == g.number_of_nodes()
            num_edges = 4 if replace else 3
            assert subg.number_of_edges() == num_edges
            u, v = subg.edges()
            assert set(F.asnumpy(F.unique(u))) == {0, 2}
            assert F.array_equal(g.has_edges_between(u, v), F.ones((num_edges,), dtype=F.int64))
            assert F.array_equal(g.edge_ids(u, v), subg.edata[dgl.EID])
            edge_set = set(zip(list(F.asnumpy(u)), list(F.asnumpy(v))))
            if not replace:
                # check no duplication
                assert len(edge_set) == num_edges
            if p is not None:
                assert (0, 3) not in edge_set
    _test2(None, True)     # w/ replacement, uniform
    _test2(None, False)    # w/o replacement, uniform
    _test2('prob', True)   # w/ replacement
    _test2('prob', False)  # w/o replacement

    def _test3(p, replace):
        for _ in range(10):
            subg = dgl.sampling.sample_neighbors(hg, {'user' : [0,1], 'game' : 0}, 2, prob=p, replace=replace, edge_dir='out')
            assert len(subg.ntypes) == 3
            assert len(subg.etypes) == 4
            assert subg['follow'].number_of_edges() == 4
            # The conditional must be parenthesized: `assert a == 2 if replace else 1`
            # parses as `assert ((a == 2) if replace else 1)`, which can never
            # fail when replace is False.
            assert subg['play'].number_of_edges() == (2 if replace else 1)
            assert subg['liked-by'].number_of_edges() == (4 if replace else 3)
            assert subg['flips'].number_of_edges() == 0
    _test3(None, True)     # w/ replacement, uniform
    _test3(None, False)    # w/o replacement, uniform
    _test3('prob', True)   # w/ replacement
    _test3('prob', False)  # w/o replacement
def _test_sample_neighbors_topk(hypersparse):
    """Check dgl.sampling.sample_neighbors_topk with its default edge direction.

    Uses the forward (non-reversed) top-k test graphs and the 'weight' edge
    feature; the selected edges are compared against fixed expected sets.

    Parameters
    ----------
    hypersparse : bool
        Forwarded to _gen_neighbor_topk_test_graph.
    """
    g, hg = _gen_neighbor_topk_test_graph(hypersparse, False)

    def _pairs(src, dst):
        # Materialize the (src, dst) endpoints as a Python set of tuples.
        return set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))

    def _test1():
        picked = dgl.sampling.sample_neighbors_topk(g, [0, 1], 2, 'weight')
        assert picked.number_of_nodes() == g.number_of_nodes()
        assert picked.number_of_edges() == 4
        src, dst = picked.edges()
        edges = _pairs(src, dst)
        assert F.array_equal(g.edge_ids(src, dst), picked.edata[dgl.EID])
        assert edges == {(2,0),(1,0),(2,1),(3,1)}
    _test1()

    def _test2():  # k > #neighbors
        picked = dgl.sampling.sample_neighbors_topk(g, [0, 2], 2, 'weight')
        assert picked.number_of_nodes() == g.number_of_nodes()
        assert picked.number_of_edges() == 3
        src, dst = picked.edges()
        assert F.array_equal(g.edge_ids(src, dst), picked.edata[dgl.EID])
        edges = _pairs(src, dst)
        assert edges == {(2,0),(1,0),(0,2)}
    _test2()

    def _test3():
        picked = dgl.sampling.sample_neighbors_topk(hg, {'user' : [0,1], 'game' : 0}, 2, 'weight')
        assert len(picked.ntypes) == 3
        assert len(picked.etypes) == 4
        expected_per_etype = [
            ('follow', {(2,0),(1,0),(2,1),(3,1)}),
            ('play', {(0,0)}),
            ('liked-by', {(2,0),(2,1),(1,0)}),
        ]
        for etype, expected in expected_per_etype:
            src, dst = picked[etype].edges()
            edges = _pairs(src, dst)
            assert F.array_equal(hg[etype].edge_ids(src, dst), picked[etype].edata[dgl.EID])
            assert edges == expected
        assert picked['flips'].number_of_edges() == 0
    _test3()

    # test different k for different relations
    picked = dgl.sampling.sample_neighbors_topk(hg, {'user' : [0,1], 'game' : 0}, [1, 2, 0, 2], 'weight')
    assert len(picked.ntypes) == 3
    assert len(picked.etypes) == 4
    expected_counts = [('follow', 2), ('play', 1), ('liked-by', 0), ('flips', 0)]
    for etype, count in expected_counts:
        assert picked[etype].number_of_edges() == count
def _test_sample_neighbors_topk_outedge(hypersparse):
    """Check dgl.sampling.sample_neighbors_topk with ``edge_dir='out'``.

    Uses the reversed top-k test graphs and the 'weight' edge feature; the
    selected edges are compared against fixed expected sets.

    Parameters
    ----------
    hypersparse : bool
        Forwarded to _gen_neighbor_topk_test_graph.
    """
    g, hg = _gen_neighbor_topk_test_graph(hypersparse, True)

    def _pairs(src, dst):
        # Materialize the (src, dst) endpoints as a Python set of tuples.
        return set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))

    def _test1():
        picked = dgl.sampling.sample_neighbors_topk(g, [0, 1], 2, 'weight', edge_dir='out')
        assert picked.number_of_nodes() == g.number_of_nodes()
        assert picked.number_of_edges() == 4
        src, dst = picked.edges()
        edges = _pairs(src, dst)
        assert F.array_equal(g.edge_ids(src, dst), picked.edata[dgl.EID])
        assert edges == {(0,2),(0,1),(1,2),(1,3)}
    _test1()

    def _test2():  # k > #neighbors
        picked = dgl.sampling.sample_neighbors_topk(g, [0, 2], 2, 'weight', edge_dir='out')
        assert picked.number_of_nodes() == g.number_of_nodes()
        assert picked.number_of_edges() == 3
        src, dst = picked.edges()
        edges = _pairs(src, dst)
        assert F.array_equal(g.edge_ids(src, dst), picked.edata[dgl.EID])
        assert edges == {(0,2),(0,1),(2,0)}
    _test2()

    def _test3():
        picked = dgl.sampling.sample_neighbors_topk(hg, {'user' : [0,1], 'game' : 0}, 2, 'weight', edge_dir='out')
        assert len(picked.ntypes) == 3
        assert len(picked.etypes) == 4
        expected_per_etype = [
            ('follow', {(0,2),(0,1),(1,2),(1,3)}),
            ('play', {(0,0)}),
            ('liked-by', {(0,2),(1,2),(0,1)}),
        ]
        for etype, expected in expected_per_etype:
            src, dst = picked[etype].edges()
            edges = _pairs(src, dst)
            assert F.array_equal(hg[etype].edge_ids(src, dst), picked[etype].edata[dgl.EID])
            assert edges == expected
        assert picked['flips'].number_of_edges() == 0
    _test3()
@unittest.skipIf(F._default_context_str == 'gpu', reason="GPU sample neighbors not implemented")
def test_sample_neighbors():
    """Run the neighbor sampling checks on the regular graphs, then the hypersparse ones."""
    for hypersparse in (False, True):
        _test_sample_neighbors(hypersparse)
@unittest.skipIf(F._default_context_str == 'gpu', reason="GPU sample neighbors not implemented")
def test_sample_neighbors_outedge():
    """Run the out-edge sampling checks on the regular graphs, then the hypersparse ones."""
    for hypersparse in (False, True):
        _test_sample_neighbors_outedge(hypersparse)
@unittest.skipIf(F._default_context_str == 'gpu', reason="GPU sample neighbors not implemented")
def test_sample_neighbors_topk():
    """Run the top-k neighbor checks with hypersparse=False, then hypersparse=True."""
    for hypersparse in (False, True):
        _test_sample_neighbors_topk(hypersparse)
@unittest.skipIf(F._default_context_str == 'gpu', reason="GPU sample neighbors not implemented")
def test_sample_neighbors_topk_outedge():
    """Run the out-edge top-k checks with hypersparse=False, then hypersparse=True."""
    for hypersparse in (False, True):
        _test_sample_neighbors_topk_outedge(hypersparse)
if __name__ == '__main__': if __name__ == '__main__':
test_random_walk() test_random_walk()
test_pack_traces() test_pack_traces()
test_sample_neighbors()
test_sample_neighbors_outedge()
test_sample_neighbors_topk()
test_sample_neighbors_topk_outedge()
...@@ -5,12 +5,12 @@ import dgl ...@@ -5,12 +5,12 @@ import dgl
import dgl.function as fn import dgl.function as fn
import backend as F import backend as F
from dgl.graph_index import from_scipy_sparse_matrix from dgl.graph_index import from_scipy_sparse_matrix
import unittest
D = 5 D = 5
# line graph related # line graph related
def test_line_graph(): def test_line_graph():
N = 5 N = 5
G = dgl.DGLGraph(nx.star_graph(N)) G = dgl.DGLGraph(nx.star_graph(N))
...@@ -231,6 +231,56 @@ def test_partition(): ...@@ -231,6 +231,56 @@ def test_partition():
block_eids2 = F.asnumpy(F.gather_row(subg.parent_eid, block_eids2)) block_eids2 = F.asnumpy(F.gather_row(subg.parent_eid, block_eids2))
assert np.all(np.sort(block_eids1) == np.sort(block_eids2)) assert np.all(np.sort(block_eids1) == np.sort(block_eids2))
@unittest.skipIf(F._default_context_str == 'gpu', reason="GPU not implemented")
def test_in_subgraph():
    """Check dgl.in_subgraph on a four-relation heterograph.

    For seeds user {0, 1} and game {0}, the returned edges of each relation
    are compared with a fixed expected set, and the dgl.EID edge feature is
    checked against the edge ids looked up in the parent graph.
    """
    relations = [
        dgl.graph([(1,0),(2,0),(3,0),(0,1),(2,1),(3,1),(0,2)], 'user', 'follow'),
        dgl.bipartite([(0,0),(0,1),(1,2),(3,2)], 'user', 'play', 'game'),
        dgl.bipartite([(2,0),(2,1),(2,2),(1,0),(1,3),(0,0)], 'game', 'liked-by', 'user'),
        dgl.bipartite([(0,0),(1,0),(2,0),(3,0)], 'user', 'flips', 'coin'),
    ]
    hg = dgl.hetero_from_relations(relations)
    subg = dgl.in_subgraph(hg, {'user' : [0,1], 'game' : 0})
    assert len(subg.ntypes) == 3
    assert len(subg.etypes) == 4

    expected_per_etype = [
        ('follow', {(1,0),(2,0),(3,0),(0,1),(2,1),(3,1)}),
        ('play', {(0,0)}),
        ('liked-by', {(2,0),(2,1),(1,0),(0,0)}),
    ]
    for etype, expected in expected_per_etype:
        src, dst = subg[etype].edges()
        found = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))
        assert F.array_equal(hg[etype].edge_ids(src, dst), subg[etype].edata[dgl.EID])
        assert found == expected
    assert subg['flips'].number_of_edges() == 0
@unittest.skipIf(F._default_context_str == 'gpu', reason="GPU not implemented")
def test_out_subgraph():
    """Check dgl.out_subgraph on a four-relation heterograph.

    For seeds user {0, 1} and game {0}, the returned edges of each relation
    are compared with a fixed expected set, and the dgl.EID edge feature is
    checked against the edge ids looked up in the parent graph.
    """
    relations = [
        dgl.graph([(1,0),(2,0),(3,0),(0,1),(2,1),(3,1),(0,2)], 'user', 'follow'),
        dgl.bipartite([(0,0),(0,1),(1,2),(3,2)], 'user', 'play', 'game'),
        dgl.bipartite([(2,0),(2,1),(2,2),(1,0),(1,3),(0,0)], 'game', 'liked-by', 'user'),
        dgl.bipartite([(0,0),(1,0),(2,0),(3,0)], 'user', 'flips', 'coin'),
    ]
    hg = dgl.hetero_from_relations(relations)
    subg = dgl.out_subgraph(hg, {'user' : [0,1], 'game' : 0})
    assert len(subg.ntypes) == 3
    assert len(subg.etypes) == 4

    expected_per_etype = [
        ('follow', {(1,0),(0,1),(0,2)}),
        ('play', {(0,0),(0,1),(1,2)}),
        ('liked-by', {(0,0)}),
        ('flips', {(0,0),(1,0)}),
    ]
    for etype, expected in expected_per_etype:
        src, dst = subg[etype].edges()
        found = set(zip(list(F.asnumpy(src)), list(F.asnumpy(dst))))
        assert found == expected
        assert F.array_equal(hg[etype].edge_ids(src, dst), subg[etype].edata[dgl.EID])
if __name__ == '__main__': if __name__ == '__main__':
test_line_graph() test_line_graph()
...@@ -245,3 +295,5 @@ if __name__ == '__main__': ...@@ -245,3 +295,5 @@ if __name__ == '__main__':
test_remove_self_loop() test_remove_self_loop()
test_add_self_loop() test_add_self_loop()
test_partition() test_partition()
test_in_subgraph()
test_out_subgraph()
...@@ -35,6 +35,17 @@ inline bool ArrayEQ(dgl::runtime::NDArray a1, dgl::runtime::NDArray a2) { ...@@ -35,6 +35,17 @@ inline bool ArrayEQ(dgl::runtime::NDArray a1, dgl::runtime::NDArray a2) {
return true; return true;
} }
/*!
 * \brief Linear search for a value in an array.
 *
 * The array buffer is reinterpreted as elements of type T and the first
 * a->shape[0] entries are scanned.
 *
 * \param a Array to search; may be undefined or empty.
 * \param x Value to look for.
 * \return true if some scanned element compares equal to x, false otherwise
 *         (including for an undefined or empty array).
 */
template <typename T>
inline bool IsInArray(dgl::runtime::NDArray a, T x) {
  if (a.defined()) {
    const int64_t len = a->shape[0];
    const T* values = static_cast<T*>(a->data);
    for (int64_t pos = 0; pos < len; ++pos) {
      if (x == values[pos])
        return true;
    }
  }
  return false;
}
static constexpr DLContext CTX = DLContext{kDLCPU, 0}; static constexpr DLContext CTX = DLContext{kDLCPU, 0};
#endif // TEST_COMMON_H_ #endif // TEST_COMMON_H_
#include <gtest/gtest.h>
#include <dgl/array.h>
#include <tuple>
#include <set>
#include "./common.h"
using namespace dgl;
using namespace dgl::runtime;
using namespace dgl::aten;
/*! \brief One edge described as a (row, col, edge id) triple. */
template <typename Idx>
using ETuple = std::tuple<Idx, Idx, Idx>;

/*!
 * \brief All edges of the 4x4 test matrix built by CSR<Idx>() / COO<Idx>().
 * \param has_data If true, edge ids follow the explicit data array
 *        {2, 3, 0, 1, 4}; otherwise the edge id equals the edge position.
 * \return The set of five (row, col, edge id) triples.
 */
template <typename Idx>
std::set<ETuple<Idx>> AllEdgeSet(bool has_data) {
  using Edge = ETuple<Idx>;
  if (!has_data) {
    return std::set<Edge>{
        Edge{0, 0, 0}, Edge{0, 1, 1}, Edge{1, 1, 2},
        Edge{3, 2, 3}, Edge{3, 3, 4}};
  }
  return std::set<Edge>{
      Edge{0, 0, 2}, Edge{0, 1, 3}, Edge{1, 1, 0},
      Edge{3, 2, 1}, Edge{3, 3, 4}};
}
/*!
 * \brief Collect the entries of a COO matrix as a set of (row, col, data) triples.
 *
 * The row, col and data buffers are read as arrays of Idx; the number of
 * entries is taken from the length of the row array.
 *
 * \param mat COO matrix to convert.
 * \return Set of triples; duplicate entries collapse into one element.
 */
template <typename Idx>
std::set<ETuple<Idx>> ToEdgeSet(COOMatrix mat) {
  const Idx* src = static_cast<Idx*>(mat.row->data);
  const Idx* dst = static_cast<Idx*>(mat.col->data);
  const Idx* eid = static_cast<Idx*>(mat.data->data);
  const int64_t num_entries = mat.row->shape[0];
  std::set<ETuple<Idx>> edges;
  for (int64_t e = 0; e < num_entries; ++e)
    edges.insert(ETuple<Idx>(src[e], dst[e], eid[e]));
  return edges;
}
/*!
 * \brief Validate a sampled COO matrix against the ground-truth edge set.
 *
 * Asserts that the matrix is 4x4, that every (row, col, data) entry is one of
 * the edges returned by AllEdgeSet<Idx>(has_data), and that every entry's row
 * is contained in \a rows.
 *
 * \param mat Sampling result to check.
 * \param rows Row ids the entries are required to come from.
 * \param has_data Selects which ground-truth edge ids to compare against.
 */
template <typename Idx>
void CheckSampledResult(COOMatrix mat, IdArray rows, bool has_data) {
  ASSERT_EQ(mat.num_rows, 4);
  ASSERT_EQ(mat.num_cols, 4);
  const auto& gt = AllEdgeSet<Idx>(has_data);
  const Idx* src = static_cast<Idx*>(mat.row->data);
  const Idx* dst = static_cast<Idx*>(mat.col->data);
  const Idx* eid = static_cast<Idx*>(mat.data->data);
  const int64_t num_entries = mat.row->shape[0];
  for (int64_t e = 0; e < num_entries; ++e) {
    const auto edge = std::make_tuple(src[e], dst[e], eid[e]);
    ASSERT_TRUE(gt.count(edge));
    ASSERT_TRUE(IsInArray(rows, src[e]));
  }
}
/*!
 * \brief Build the 4x4 CSR test matrix with five entries.
 * \param has_data If true, attach the data array {2, 3, 0, 1, 4}; otherwise
 *        construct the matrix without a data array.
 */
template <typename Idx>
CSRMatrix CSR(bool has_data) {
  const std::vector<Idx> indptr_vec({0, 2, 3, 3, 5});
  const std::vector<Idx> indices_vec({0, 1, 1, 2, 3});
  const std::vector<Idx> data_vec({2, 3, 0, 1, 4});
  IdArray indptr = NDArray::FromVector(indptr_vec);
  IdArray indices = NDArray::FromVector(indices_vec);
  IdArray data = NDArray::FromVector(data_vec);
  return has_data ? CSRMatrix(4, 4, indptr, indices, data)
                  : CSRMatrix(4, 4, indptr, indices);
}
/*!
 * \brief Build the 4x4 COO test matrix with five entries.
 * \param has_data If true, attach the data array {2, 3, 0, 1, 4}; otherwise
 *        construct the matrix without a data array.
 */
template <typename Idx>
COOMatrix COO(bool has_data) {
  const std::vector<Idx> row_vec({0, 0, 1, 3, 3});
  const std::vector<Idx> col_vec({0, 1, 1, 2, 3});
  const std::vector<Idx> data_vec({2, 3, 0, 1, 4});
  IdArray row = NDArray::FromVector(row_vec);
  IdArray col = NDArray::FromVector(col_vec);
  IdArray data = NDArray::FromVector(data_vec);
  return has_data ? COOMatrix(4, 4, row, col, data)
                  : COOMatrix(4, 4, row, col);
}
template <typename Idx, typename FloatType>
void _TestCSRSampling(bool has_data) {
auto mat = CSR<Idx>(has_data);
FloatArray prob = NDArray::FromVector(
std::vector<FloatType>({.5, .5, .5, .5, .5}));
IdArray rows = NDArray::FromVector(std::vector<Idx>({0, 3}));
for (int k = 0; k < 10; ++k) {
auto rst = CSRRowWiseSampling(mat, rows, 2, prob, true);
CheckSampledResult<Idx>(rst, rows, has_data);
}
for (int k = 0; k < 10; ++k) {
auto rst = CSRRowWiseSampling(mat, rows, 2, prob, false);
CheckSampledResult<Idx>(rst, rows, has_data);
auto eset = ToEdgeSet<Idx>(rst);
ASSERT_EQ(eset.size(), 4);
if (has_data) {
ASSERT_TRUE(eset.count(std::make_tuple(0, 0, 2)));
ASSERT_TRUE(eset.count(std::make_tuple(0, 1, 3)));
ASSERT_TRUE(eset.count(std::make_tuple(3, 2, 1)));
ASSERT_TRUE(eset.count(std::make_tuple(3, 3, 4)));
} else {
ASSERT_TRUE(eset.count(std::make_tuple(0, 0, 0)));
ASSERT_TRUE(eset.count(std::make_tuple(0, 1, 1)));
ASSERT_TRUE(eset.count(std::make_tuple(3, 2, 3)));
ASSERT_TRUE(eset.count(std::make_tuple(3, 3, 4)));
}
}
prob = NDArray::FromVector(
std::vector<FloatType>({.0, .5, .5, .0, .5}));
for (int k = 0; k < 100; ++k) {
auto rst = CSRRowWiseSampling(mat, rows, 2, prob, true);
CheckSampledResult<Idx>(rst, rows, has_data);
auto eset = ToEdgeSet<Idx>(rst);
if (has_data) {
ASSERT_FALSE(eset.count(std::make_tuple(0, 1, 3)));
} else {
ASSERT_FALSE(eset.count(std::make_tuple(0, 0, 0)));
ASSERT_FALSE(eset.count(std::make_tuple(3, 2, 3)));
}
}
}
// Runs _TestCSRSampling for every index/float type combination, first with
// and then without an explicit data array.
TEST(RowwiseTest, TestCSRSampling) {
  for (const bool has_data : {true, false}) {
    _TestCSRSampling<int32_t, float>(has_data);
    _TestCSRSampling<int64_t, float>(has_data);
    _TestCSRSampling<int32_t, double>(has_data);
    _TestCSRSampling<int64_t, double>(has_data);
  }
}
template <typename Idx, typename FloatType>
void _TestCSRSamplingUniform(bool has_data) {
auto mat = CSR<Idx>(has_data);
FloatArray prob;
IdArray rows = NDArray::FromVector(std::vector<Idx>({0, 3}));
for (int k = 0; k < 10; ++k) {
auto rst = CSRRowWiseSampling(mat, rows, 2, prob, true);
CheckSampledResult<Idx>(rst, rows, has_data);
}
for (int k = 0; k < 10; ++k) {
auto rst = CSRRowWiseSampling(mat, rows, 2, prob, false);
CheckSampledResult<Idx>(rst, rows, has_data);
auto eset = ToEdgeSet<Idx>(rst);
if (has_data) {
ASSERT_TRUE(eset.count(std::make_tuple(0, 0, 2)));
ASSERT_TRUE(eset.count(std::make_tuple(0, 1, 3)));
ASSERT_TRUE(eset.count(std::make_tuple(3, 2, 1)));
ASSERT_TRUE(eset.count(std::make_tuple(3, 3, 4)));
} else {
ASSERT_TRUE(eset.count(std::make_tuple(0, 0, 0)));
ASSERT_TRUE(eset.count(std::make_tuple(0, 1, 1)));
ASSERT_TRUE(eset.count(std::make_tuple(3, 2, 3)));
ASSERT_TRUE(eset.count(std::make_tuple(3, 3, 4)));
}
}
}
// Runs _TestCSRSamplingUniform for every index/float type combination, first
// with and then without an explicit data array.
TEST(RowwiseTest, TestCSRSamplingUniform) {
  for (const bool has_data : {true, false}) {
    _TestCSRSamplingUniform<int32_t, float>(has_data);
    _TestCSRSamplingUniform<int64_t, float>(has_data);
    _TestCSRSamplingUniform<int32_t, double>(has_data);
    _TestCSRSamplingUniform<int64_t, double>(has_data);
  }
}
template <typename Idx, typename FloatType>
void _TestCOOSampling(bool has_data) {
auto mat = COO<Idx>(has_data);
FloatArray prob = NDArray::FromVector(
std::vector<FloatType>({.5, .5, .5, .5, .5}));
IdArray rows = NDArray::FromVector(std::vector<Idx>({0, 3}));
for (int k = 0; k < 10; ++k) {
auto rst = COORowWiseSampling(mat, rows, 2, prob, true);
CheckSampledResult<Idx>(rst, rows, has_data);
}
for (int k = 0; k < 10; ++k) {
auto rst = COORowWiseSampling(mat, rows, 2, prob, false);
CheckSampledResult<Idx>(rst, rows, has_data);
auto eset = ToEdgeSet<Idx>(rst);
ASSERT_EQ(eset.size(), 4);
if (has_data) {
ASSERT_TRUE(eset.count(std::make_tuple(0, 0, 2)));
ASSERT_TRUE(eset.count(std::make_tuple(0, 1, 3)));
ASSERT_TRUE(eset.count(std::make_tuple(3, 2, 1)));
ASSERT_TRUE(eset.count(std::make_tuple(3, 3, 4)));
} else {
ASSERT_TRUE(eset.count(std::make_tuple(0, 0, 0)));
ASSERT_TRUE(eset.count(std::make_tuple(0, 1, 1)));
ASSERT_TRUE(eset.count(std::make_tuple(3, 2, 3)));
ASSERT_TRUE(eset.count(std::make_tuple(3, 3, 4)));
}
}
prob = NDArray::FromVector(
std::vector<FloatType>({.0, .5, .5, .0, .5}));
for (int k = 0; k < 100; ++k) {
auto rst = COORowWiseSampling(mat, rows, 2, prob, true);
CheckSampledResult<Idx>(rst, rows, has_data);
auto eset = ToEdgeSet<Idx>(rst);
if (has_data) {
ASSERT_FALSE(eset.count(std::make_tuple(0, 1, 3)));
} else {
ASSERT_FALSE(eset.count(std::make_tuple(0, 0, 0)));
ASSERT_FALSE(eset.count(std::make_tuple(3, 2, 3)));
}
}
}
// Runs _TestCOOSampling for every index/float type combination, first with
// and then without an explicit data array.
TEST(RowwiseTest, TestCOOSampling) {
  for (const bool has_data : {true, false}) {
    _TestCOOSampling<int32_t, float>(has_data);
    _TestCOOSampling<int64_t, float>(has_data);
    _TestCOOSampling<int32_t, double>(has_data);
    _TestCOOSampling<int64_t, double>(has_data);
  }
}
template <typename Idx, typename FloatType>
void _TestCOOSamplingUniform(bool has_data) {
auto mat = COO<Idx>(has_data);
FloatArray prob;
IdArray rows = NDArray::FromVector(std::vector<Idx>({0, 3}));
for (int k = 0; k < 10; ++k) {
auto rst = COORowWiseSampling(mat, rows, 2, prob, true);
CheckSampledResult<Idx>(rst, rows, has_data);
}
for (int k = 0; k < 10; ++k) {
auto rst = COORowWiseSampling(mat, rows, 2, prob, false);
CheckSampledResult<Idx>(rst, rows, has_data);
auto eset = ToEdgeSet<Idx>(rst);
if (has_data) {
ASSERT_TRUE(eset.count(std::make_tuple(0, 0, 2)));
ASSERT_TRUE(eset.count(std::make_tuple(0, 1, 3)));
ASSERT_TRUE(eset.count(std::make_tuple(3, 2, 1)));
ASSERT_TRUE(eset.count(std::make_tuple(3, 3, 4)));
} else {
ASSERT_TRUE(eset.count(std::make_tuple(0, 0, 0)));
ASSERT_TRUE(eset.count(std::make_tuple(0, 1, 1)));
ASSERT_TRUE(eset.count(std::make_tuple(3, 2, 3)));
ASSERT_TRUE(eset.count(std::make_tuple(3, 3, 4)));
}
}
}
// Runs _TestCOOSamplingUniform for every index/float type combination, first
// with and then without an explicit data array.
TEST(RowwiseTest, TestCOOSamplingUniform) {
  for (const bool has_data : {true, false}) {
    _TestCOOSamplingUniform<int32_t, float>(has_data);
    _TestCOOSamplingUniform<int64_t, float>(has_data);
    _TestCOOSamplingUniform<int32_t, double>(has_data);
    _TestCOOSamplingUniform<int64_t, double>(has_data);
  }
}
template <typename Idx, typename FloatType>
void _TestCSRTopk(bool has_data) {
auto mat = CSR<Idx>(has_data);
FloatArray weight = NDArray::FromVector(
std::vector<FloatType>({.1, .0, -.1, .2, .5}));
// -.1, .2, .1, .0, .5
IdArray rows = NDArray::FromVector(std::vector<Idx>({0, 3}));
{
auto rst = CSRRowWiseTopk(mat, rows, 1, weight, true);
auto eset = ToEdgeSet<Idx>(rst);
ASSERT_EQ(eset.size(), 2);
if (has_data) {
ASSERT_TRUE(eset.count(std::make_tuple(0, 0, 2)));
ASSERT_TRUE(eset.count(std::make_tuple(3, 2, 1)));
} else {
ASSERT_TRUE(eset.count(std::make_tuple(0, 1, 1)));
ASSERT_TRUE(eset.count(std::make_tuple(3, 2, 3)));
}
}
{
auto rst = CSRRowWiseTopk(mat, rows, 1, weight, false);
auto eset = ToEdgeSet<Idx>(rst);
ASSERT_EQ(eset.size(), 2);
if (has_data) {
ASSERT_TRUE(eset.count(std::make_tuple(0, 1, 3)));
ASSERT_TRUE(eset.count(std::make_tuple(3, 3, 4)));
} else {
ASSERT_TRUE(eset.count(std::make_tuple(0, 0, 0)));
ASSERT_TRUE(eset.count(std::make_tuple(3, 3, 4)));
}
}
}
// Runs _TestCSRTopk for every index/float type combination, first with and
// then without an explicit data array.
TEST(RowwiseTest, TestCSRTopk) {
  for (const bool has_data : {true, false}) {
    _TestCSRTopk<int32_t, float>(has_data);
    _TestCSRTopk<int64_t, float>(has_data);
    _TestCSRTopk<int32_t, double>(has_data);
    _TestCSRTopk<int64_t, double>(has_data);
  }
}
template <typename Idx, typename FloatType>
void _TestCOOTopk(bool has_data) {
auto mat = COO<Idx>(has_data);
FloatArray weight = NDArray::FromVector(
std::vector<FloatType>({.1, .0, -.1, .2, .5}));
// -.1, .2, .1, .0, .5
IdArray rows = NDArray::FromVector(std::vector<Idx>({0, 3}));
{
auto rst = COORowWiseTopk(mat, rows, 1, weight, true);
auto eset = ToEdgeSet<Idx>(rst);
ASSERT_EQ(eset.size(), 2);
if (has_data) {
ASSERT_TRUE(eset.count(std::make_tuple(0, 0, 2)));
ASSERT_TRUE(eset.count(std::make_tuple(3, 2, 1)));
} else {
ASSERT_TRUE(eset.count(std::make_tuple(0, 1, 1)));
ASSERT_TRUE(eset.count(std::make_tuple(3, 2, 3)));
}
}
{
auto rst = COORowWiseTopk(mat, rows, 1, weight, false);
auto eset = ToEdgeSet<Idx>(rst);
ASSERT_EQ(eset.size(), 2);
if (has_data) {
ASSERT_TRUE(eset.count(std::make_tuple(0, 1, 3)));
ASSERT_TRUE(eset.count(std::make_tuple(3, 3, 4)));
} else {
ASSERT_TRUE(eset.count(std::make_tuple(0, 0, 0)));
ASSERT_TRUE(eset.count(std::make_tuple(3, 3, 4)));
}
}
}
// Runs _TestCOOTopk for every index/float type combination, first with and
// then without an explicit data array.
TEST(RowwiseTest, TestCOOTopk) {
  for (const bool has_data : {true, false}) {
    _TestCOOTopk<int32_t, float>(has_data);
    _TestCOOTopk<int64_t, float>(has_data);
    _TestCOOTopk<int32_t, double>(has_data);
    _TestCOOTopk<int64_t, double>(has_data);
  }
}
...@@ -25,7 +25,7 @@ void _TestWithReplacement(RandomEngine *re) { ...@@ -25,7 +25,7 @@ void _TestWithReplacement(RandomEngine *re) {
FloatArray prob = NDArray::FromVector(_prob); FloatArray prob = NDArray::FromVector(_prob);
auto _check_given_sampler = [n_categories, n_rolls, &_prob]( auto _check_given_sampler = [n_categories, n_rolls, &_prob](
utils::BaseSampler<Idx, DType, true> *s) { utils::BaseSampler<Idx> *s) {
std::vector<Idx> counter(n_categories, 0); std::vector<Idx> counter(n_categories, 0);
for (Idx i = 0; i < n_rolls; ++i) { for (Idx i = 0; i < n_rolls; ++i) {
Idx dice = s->Draw(); Idx dice = s->Draw();
...@@ -74,7 +74,7 @@ void _TestWithoutReplacementOrder(RandomEngine *re) { ...@@ -74,7 +74,7 @@ void _TestWithoutReplacementOrder(RandomEngine *re) {
std::vector<Idx> ground_truth = {0, 3, 2, 1}; std::vector<Idx> ground_truth = {0, 3, 2, 1};
auto _check_given_sampler = [&ground_truth]( auto _check_given_sampler = [&ground_truth](
utils::BaseSampler<Idx, DType, false> *s) { utils::BaseSampler<Idx> *s) {
for (size_t i = 0; i < ground_truth.size(); ++i) { for (size_t i = 0; i < ground_truth.size(); ++i) {
Idx dice = s->Draw(); Idx dice = s->Draw();
ASSERT_EQ(dice, ground_truth[i]); ASSERT_EQ(dice, ground_truth[i]);
...@@ -110,7 +110,7 @@ void _TestWithoutReplacementUnique(RandomEngine *re) { ...@@ -110,7 +110,7 @@ void _TestWithoutReplacementUnique(RandomEngine *re) {
FloatArray likelihood = NDArray::FromVector(_likelihood); FloatArray likelihood = NDArray::FromVector(_likelihood);
auto _check_given_sampler = [N]( auto _check_given_sampler = [N](
utils::BaseSampler<Idx, DType, false> *s) { utils::BaseSampler<Idx> *s) {
std::vector<int> cnt(N, 0); std::vector<int> cnt(N, 0);
for (Idx i = 0; i < N; ++i) { for (Idx i = 0; i < N; ++i) {
Idx dice = s->Draw(); Idx dice = s->Draw();
...@@ -139,3 +139,91 @@ TEST(SampleUtilsTest, TestWithoutReplacementUnique) { ...@@ -139,3 +139,91 @@ TEST(SampleUtilsTest, TestWithoutReplacementUnique) {
re->SetSeed(42); re->SetSeed(42);
_TestWithoutReplacementUnique<int64_t, double>(re); _TestWithoutReplacementUnique<int64_t, double>(re);
}; };
/*!
 * \brief Exercise RandomEngine::Choice with a weight array whose only
 *        non-zero entries are at indices 0, 4 and 5.
 *
 * After seeding the engine with 42 it checks: 1000 single draws, a request
 * for zero samples, 1000 samples with replacement, and 3 samples without
 * replacement (which must be exactly {0, 4, 5}).
 *
 * \param re Random engine under test.
 */
template <typename Idx, typename DType>
void _TestChoice(RandomEngine* re) {
  re->SetSeed(42);
  const std::vector<DType> weights = {1., 0., 0., 0., 2., 2., 0., 0.};
  FloatArray prob = FloatArray::FromVector(weights);

  // single draws
  for (int trial = 0; trial < 1000; ++trial) {
    const Idx pick = re->Choice<Idx>(prob);
    ASSERT_TRUE(pick == 0 || pick == 4 || pick == 5);
  }

  // num = 0
  {
    IdArray none = re->Choice<Idx, DType>(0, prob, true);
    ASSERT_EQ(none->shape[0], 0);
  }

  // w/ replacement
  {
    IdArray drawn = re->Choice<Idx, DType>(1000, prob, true);
    ASSERT_EQ(drawn->shape[0], 1000);
    const Idx* picks = static_cast<Idx*>(drawn->data);
    for (int64_t pos = 0; pos < 1000; ++pos) {
      const Idx pick = picks[pos];
      ASSERT_TRUE(pick == 0 || pick == 4 || pick == 5);
    }
  }

  // w/o replacement
  {
    IdArray drawn = re->Choice<Idx, DType>(3, prob, false);
    ASSERT_EQ(drawn->shape[0], 3);
    const Idx* picks = static_cast<Idx*>(drawn->data);
    const std::set<Idx> distinct(picks, picks + 3);
    ASSERT_EQ(distinct.size(), 3);
    ASSERT_EQ(distinct.count(0), 1);
    ASSERT_EQ(distinct.count(4), 1);
    ASSERT_EQ(distinct.count(5), 1);
  }
}
// Runs the Choice checks for every combination of 32/64-bit index type and
// float/double weight type, all on the calling thread's random engine.
TEST(RandomTest, TestChoice) {
  RandomEngine* engine = RandomEngine::ThreadLocal();
  _TestChoice<int32_t, float>(engine);
  _TestChoice<int64_t, float>(engine);
  _TestChoice<int32_t, double>(engine);
  _TestChoice<int64_t, double>(engine);
}
/*!
 * \brief Checks RandomEngine::UniformChoice for the empty, with-replacement
 *        and without-replacement cases.
 *
 * Every call passes 100 as the second argument, and every sampled value is
 * expected to lie in [0, 100).
 *
 * \tparam Idx Integer type of the sampled indices.
 * \param re   Random engine under test; it is reseeded with 42 on entry.
 */
template <typename Idx>
void _TestUniformChoice(RandomEngine* re) {
  re->SetSeed(42);

  // Requesting zero samples must give back an empty array.
  {
    IdArray picked = re->UniformChoice<Idx>(0, 100, true);
    ASSERT_EQ(picked->shape[0], 0);
  }

  // With replacement: 1000 draws, each within range.
  {
    IdArray picked = re->UniformChoice<Idx>(1000, 100, true);
    ASSERT_EQ(picked->shape[0], 1000);
    const Idx* values = static_cast<Idx*>(picked->data);
    for (int64_t pos = 0; pos < 1000; ++pos) {
      Idx x = values[pos];
      ASSERT_TRUE(x >= 0 && x < 100);
    }
  }

  // Without replacement: 99 draws, each within range and pairwise distinct.
  {
    IdArray picked = re->UniformChoice<Idx>(99, 100, false);
    ASSERT_EQ(picked->shape[0], 99);
    const Idx* values = static_cast<Idx*>(picked->data);
    std::set<Idx> idxset;
    for (int64_t pos = 0; pos < 99; ++pos) {
      Idx x = values[pos];
      ASSERT_TRUE(x >= 0 && x < 100);
      idxset.insert(x);
    }
    // A repeated value would make the set smaller than the sample count.
    ASSERT_EQ(idxset.size(), 99);
  }
}
// Runs the UniformChoice checks once per index width on the calling thread's
// random engine. _TestUniformChoice has a single template parameter and
// reseeds the engine on entry, so a repeated instantiation would replay
// exactly the same draws; each width is therefore exercised only once.
TEST(RandomTest, TestUniformChoice) {
  RandomEngine* re = RandomEngine::ThreadLocal();
  _TestUniformChoice<int32_t>(re);
  _TestUniformChoice<int64_t>(re);
}
...@@ -13,7 +13,7 @@ using namespace dgl; ...@@ -13,7 +13,7 @@ using namespace dgl;
using namespace dgl::aten; using namespace dgl::aten;
using namespace dmlc; using namespace dmlc;
TEST(Serialize, UnitGraph) { TEST(Serialize, DISABLED_UnitGraph) {
aten::CSRMatrix csr_matrix; aten::CSRMatrix csr_matrix;
auto src = VecToIdArray<int64_t>({1, 2, 5, 3}); auto src = VecToIdArray<int64_t>({1, 2, 5, 3});
auto dst = VecToIdArray<int64_t>({1, 6, 2, 6}); auto dst = VecToIdArray<int64_t>({1, 6, 2, 6});
...@@ -32,9 +32,10 @@ TEST(Serialize, UnitGraph) { ...@@ -32,9 +32,10 @@ TEST(Serialize, UnitGraph) {
EXPECT_EQ(ug2->NumEdges(0), 4); EXPECT_EQ(ug2->NumEdges(0), 4);
EXPECT_EQ(ug2->FindEdge(0, 1).first, 2); EXPECT_EQ(ug2->FindEdge(0, 1).first, 2);
EXPECT_EQ(ug2->FindEdge(0, 1).second, 6); EXPECT_EQ(ug2->FindEdge(0, 1).second, 6);
delete ug2;
} }
TEST(Serialize, ImmutableGraph) { TEST(Serialize, DISABLED_ImmutableGraph) {
auto src = VecToIdArray<int64_t>({1, 2, 5, 3}); auto src = VecToIdArray<int64_t>({1, 2, 5, 3});
auto dst = VecToIdArray<int64_t>({1, 6, 2, 6}); auto dst = VecToIdArray<int64_t>({1, 6, 2, 6});
auto gptr = ImmutableGraph::CreateFromCOO(10, src, dst); auto gptr = ImmutableGraph::CreateFromCOO(10, src, dst);
...@@ -52,9 +53,10 @@ TEST(Serialize, ImmutableGraph) { ...@@ -52,9 +53,10 @@ TEST(Serialize, ImmutableGraph) {
EXPECT_EQ(rptr_read->NumVertices(), 10); EXPECT_EQ(rptr_read->NumVertices(), 10);
EXPECT_EQ(rptr_read->FindEdge(2).first, 5); EXPECT_EQ(rptr_read->FindEdge(2).first, 5);
EXPECT_EQ(rptr_read->FindEdge(2).second, 2); EXPECT_EQ(rptr_read->FindEdge(2).second, 2);
delete rptr_read;
} }
TEST(Serialize, HeteroGraph) { TEST(Serialize, DISABLED_HeteroGraph) {
auto src = VecToIdArray<int64_t>({1, 2, 5, 3}); auto src = VecToIdArray<int64_t>({1, 2, 5, 3});
auto dst = VecToIdArray<int64_t>({1, 6, 2, 6}); auto dst = VecToIdArray<int64_t>({1, 6, 2, 6});
auto mg1 = dgl::UnitGraph::CreateFromCOO(2, 9, 8, src, dst); auto mg1 = dgl::UnitGraph::CreateFromCOO(2, 9, 8, src, dst);
...@@ -78,4 +80,6 @@ TEST(Serialize, HeteroGraph) { ...@@ -78,4 +80,6 @@ TEST(Serialize, HeteroGraph) {
static_cast<dmlc::Stream*>(&ofs)->Read(gptr); static_cast<dmlc::Stream*>(&ofs)->Read(gptr);
EXPECT_EQ(gptr->NumVertices(0), 9); EXPECT_EQ(gptr->NumVertices(0), 9);
EXPECT_EQ(gptr->NumVertices(1), 8); EXPECT_EQ(gptr->NumVertices(1), 8);
} delete hrptr;
\ No newline at end of file delete gptr;
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment