Unverified Commit 5dd35580 authored by Minjie Wang's avatar Minjie Wang Committed by GitHub
Browse files

[Feature] Improve sampling speed; Better pickle/unpickle; other fixes (#1299)

* improve performance of sample_neighbors

* some more improve

* test script

* benchmarks

* multi process

* update more tests

* WIP

* adding two API for state saving

* add create from state

* upd test

* missing file

* wip: pickle/unpickle

* more c apis

* find the problem of empty data array

* add null array; pickling speed is bad

* still bad perf

* still bad perf

* wip

* fix the pickle speed test; now everything looks good

* minor fix

* bugfix

* some lint fix

* address comments

* more fix

* fix lint

* add utest for random.choice

* add utest for dgl.rand_graph

* fix cpp utests

* try fix ci

* fix bug in TF backend

* upd choice docstring

* address comments

* upd

* try fix compile

* add comment
parent 00ba4094
...@@ -19,14 +19,6 @@ inline bool operator == (const DLDataType& ty1, const DLDataType& ty2) { ...@@ -19,14 +19,6 @@ inline bool operator == (const DLDataType& ty1, const DLDataType& ty2) {
return ty1.code == ty2.code && ty1.bits == ty2.bits && ty1.lanes == ty2.lanes; return ty1.code == ty2.code && ty1.bits == ty2.bits && ty1.lanes == ty2.lanes;
} }
/*! \brief Output the string representation of device context.*/
inline std::ostream& operator << (std::ostream& os, const DLDataType& ty) {
return os <<
"code=" << static_cast<int>(ty.code) <<
",bits=" << static_cast<int>(ty.bits) <<
"lanes=" << static_cast<int>(ty.lanes);
}
/*! \brief Check whether two device contexts are the same.*/ /*! \brief Check whether two device contexts are the same.*/
inline bool operator == (const DLContext& ctx1, const DLContext& ctx2) { inline bool operator == (const DLContext& ctx1, const DLContext& ctx2) {
return ctx1.device_type == ctx2.device_type && ctx1.device_id == ctx2.device_id; return ctx1.device_type == ctx2.device_type && ctx1.device_id == ctx2.device_id;
......
/*!
* Copyright (c) 2020 by Contributors
* \file graph/creators.cc
* \brief Functions for constructing graphs.
*/
#include "./heterograph.h"
using namespace dgl::runtime;
namespace dgl {
// creator implementation
// Assembles a heterogeneous graph from its metagraph together with one
// relation graph per metagraph edge.
HeteroGraphPtr CreateHeteroGraph(
    GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& rel_graphs) {
  HeteroGraph* const hgraph = new HeteroGraph(meta_graph, rel_graphs);
  return HeteroGraphPtr(hgraph);
}
// Builds a single-relation (unit) heterograph from a COO adjacency.
// `restrict_format` limits which sparse formats the graph may materialize.
HeteroGraphPtr CreateFromCOO(
    int64_t num_vtypes, int64_t num_src, int64_t num_dst,
    IdArray row, IdArray col, SparseFormat restrict_format) {
  HeteroGraphPtr unit = UnitGraph::CreateFromCOO(
      num_vtypes, num_src, num_dst, row, col, restrict_format);
  const std::vector<HeteroGraphPtr> rels = {unit};
  return HeteroGraphPtr(new HeteroGraph(unit->meta_graph(), rels));
}
// Builds a single-relation (unit) heterograph from a CSR adjacency.
// `restrict_format` limits which sparse formats the graph may materialize.
HeteroGraphPtr CreateFromCSR(
    int64_t num_vtypes, int64_t num_src, int64_t num_dst,
    IdArray indptr, IdArray indices, IdArray edge_ids,
    SparseFormat restrict_format) {
  HeteroGraphPtr unit = UnitGraph::CreateFromCSR(
      num_vtypes, num_src, num_dst, indptr, indices, edge_ids, restrict_format);
  const std::vector<HeteroGraphPtr> rels = {unit};
  return HeteroGraphPtr(new HeteroGraph(unit->meta_graph(), rels));
}
} // namespace dgl
This diff is collapsed.
/*!
* Copyright (c) 2020 by Contributors
* \file graph/heterograph_capi.cc
* \brief Heterograph CAPI bindings.
*/
#include "./heterograph.h"
#include <dgl/packed_func_ext.h>
#include <dgl/runtime/container.h>
#include "../c_api_common.h"
using namespace dgl::runtime;
namespace dgl {
///////////////////////// Unitgraph functions /////////////////////////
// XXX(minjie): Ideally, Unitgraph should be invisible to python side
// Creates a single-relation (unit) graph from a COO adjacency.
// args: [num_vtypes, num_src, num_dst, row, col, restrict_format]
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCreateUnitGraphFromCOO")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    int64_t nvtypes = args[0];
    int64_t num_src = args[1];
    int64_t num_dst = args[2];
    IdArray row = args[3];
    IdArray col = args[4];
    // Format hint arrives as a string from Python; parse to the enum.
    SparseFormat restrict_format = ParseSparseFormat(args[5]);
    auto hgptr = CreateFromCOO(nvtypes, num_src, num_dst, row, col, restrict_format);
    *rv = HeteroGraphRef(hgptr);
  });

// Creates a single-relation (unit) graph from a CSR adjacency.
// args: [num_vtypes, num_src, num_dst, indptr, indices, edge_ids, restrict_format]
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCreateUnitGraphFromCSR")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    int64_t nvtypes = args[0];
    int64_t num_src = args[1];
    int64_t num_dst = args[2];
    IdArray indptr = args[3];
    IdArray indices = args[4];
    IdArray edge_ids = args[5];
    SparseFormat restrict_format = ParseSparseFormat(args[6]);
    auto hgptr = CreateFromCSR(nvtypes, num_src, num_dst, indptr, indices, edge_ids,
                               restrict_format);
    *rv = HeteroGraphRef(hgptr);
  });

// Assembles a heterograph from a metagraph plus one relation graph per
// metagraph edge.  args: [meta_graph, rel_graphs]
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCreateHeteroGraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    GraphRef meta_graph = args[0];
    List<HeteroGraphRef> rel_graphs = args[1];
    std::vector<HeteroGraphPtr> rel_ptrs;
    rel_ptrs.reserve(rel_graphs.size());
    for (const auto& ref : rel_graphs) {
      rel_ptrs.push_back(ref.sptr());
    }
    auto hgptr = CreateHeteroGraph(meta_graph.sptr(), rel_ptrs);
    *rv = HeteroGraphRef(hgptr);
  });
///////////////////////// HeteroGraph member functions /////////////////////////
// Returns the metagraph (the graph over node/edge *types*) of a heterograph.
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroGetMetaGraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    HeteroGraphRef hg = args[0];
    *rv = GraphRef(hg->meta_graph());
  });
// Returns the relation (unit) graph for one edge type.
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroGetRelationGraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    HeteroGraphRef hg = args[0];
    dgl_type_t etype = args[1];
    // Valid edge types are [0, NumEdgeTypes()).  The previous CHECK_LE wrongly
    // accepted etype == NumEdgeTypes(), deferring the failure to an
    // out-of-range access inside GetRelationGraph().
    CHECK_LT(etype, hg->NumEdgeTypes()) << "invalid edge type " << etype;
    // A unit graph is its own (sole) relation graph, so return it directly.
    auto bg = std::dynamic_pointer_cast<UnitGraph>(hg.sptr());
    if (bg != nullptr)
      *rv = bg;
    else
      *rv = HeteroGraphRef(hg->GetRelationGraph(etype));
  });
// Flattens the relations of the given edge types into one homogeneous view.
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroGetFlattenedGraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    HeteroGraphRef hg = args[0];
    List<Value> etypes = args[1];
    std::vector<dgl_id_t> etypes_vec;
    for (Value val : etypes) {
      // (gq) have to decompose it into two statements because of a weird MSVC internal error
      dgl_id_t id = val->data;
      etypes_vec.push_back(id);
    }
    *rv = FlattenedHeteroGraphRef(hg->Flatten(etypes_vec));
  });
// Adds `num` new vertices of type `vtype` (mutable graphs only).
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroAddVertices")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    HeteroGraphRef hg = args[0];
    dgl_type_t vtype = args[1];
    int64_t num = args[2];
    hg->AddVertices(vtype, num);
  });

// Adds one edge of type `etype` (mutable graphs only).
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroAddEdge")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    HeteroGraphRef hg = args[0];
    dgl_type_t etype = args[1];
    dgl_id_t src = args[2];
    dgl_id_t dst = args[3];
    hg->AddEdge(etype, src, dst);
  });

// Adds a batch of edges given by parallel src/dst id arrays.
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroAddEdges")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    HeteroGraphRef hg = args[0];
    dgl_type_t etype = args[1];
    IdArray src = args[2];
    IdArray dst = args[3];
    hg->AddEdges(etype, src, dst);
  });

// Removes all vertices and edges.
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroClear")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    HeteroGraphRef hg = args[0];
    hg->Clear();
  });

// Returns the dtype used for node/edge ids.
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroDataType")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    HeteroGraphRef hg = args[0];
    *rv = hg->DataType();
  });

// Returns the device context of the graph structure.
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroContext")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    HeteroGraphRef hg = args[0];
    *rv = hg->Context();
  });

// Returns the id bit-width of the graph.
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroNumBits")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    HeteroGraphRef hg = args[0];
    *rv = hg->NumBits();
  });

DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroIsMultigraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    HeteroGraphRef hg = args[0];
    *rv = hg->IsMultigraph();
  });

DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroIsReadonly")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    HeteroGraphRef hg = args[0];
    *rv = hg->IsReadonly();
  });

// Number of vertices of the given vertex type.
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroNumVertices")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    HeteroGraphRef hg = args[0];
    dgl_type_t vtype = args[1];
    // Narrowed to int64_t for the FFI return slot.
    *rv = static_cast<int64_t>(hg->NumVertices(vtype));
  });

// Number of edges of the given edge type.
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroNumEdges")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    HeteroGraphRef hg = args[0];
    dgl_type_t etype = args[1];
    *rv = static_cast<int64_t>(hg->NumEdges(etype));
  });

// Membership test for a single vertex id.
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroHasVertex")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    HeteroGraphRef hg = args[0];
    dgl_type_t vtype = args[1];
    dgl_id_t vid = args[2];
    *rv = hg->HasVertex(vtype, vid);
  });

// Vectorized membership test; returns one flag per queried id.
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroHasVertices")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    HeteroGraphRef hg = args[0];
    dgl_type_t vtype = args[1];
    IdArray vids = args[2];
    *rv = hg->HasVertices(vtype, vids);
  });

// Edge-existence test for a single (src, dst) pair.
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroHasEdgeBetween")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    HeteroGraphRef hg = args[0];
    dgl_type_t etype = args[1];
    dgl_id_t src = args[2];
    dgl_id_t dst = args[3];
    *rv = hg->HasEdgeBetween(etype, src, dst);
  });

// Vectorized edge-existence test over parallel src/dst arrays.
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroHasEdgesBetween")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    HeteroGraphRef hg = args[0];
    dgl_type_t etype = args[1];
    IdArray src = args[2];
    IdArray dst = args[3];
    *rv = hg->HasEdgesBetween(etype, src, dst);
  });
// Predecessors of `dst` along relation `etype`.
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroPredecessors")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    HeteroGraphRef hg = args[0];
    dgl_type_t etype = args[1];
    dgl_id_t dst = args[2];
    *rv = hg->Predecessors(etype, dst);
  });

// Successors of `src` along relation `etype`.
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroSuccessors")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    HeteroGraphRef hg = args[0];
    dgl_type_t etype = args[1];
    dgl_id_t src = args[2];
    *rv = hg->Successors(etype, src);
  });

// Edge id(s) connecting a single (src, dst) pair.
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroEdgeId")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    HeteroGraphRef hg = args[0];
    dgl_type_t etype = args[1];
    dgl_id_t src = args[2];
    dgl_id_t dst = args[3];
    *rv = hg->EdgeId(etype, src, dst);
  });

// Vectorized edge-id lookup; the (src, dst, eid) triple is packed for Python.
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroEdgeIds")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    HeteroGraphRef hg = args[0];
    dgl_type_t etype = args[1];
    IdArray src = args[2];
    IdArray dst = args[3];
    const auto& ret = hg->EdgeIds(etype, src, dst);
    *rv = ConvertEdgeArrayToPackedFunc(ret);
  });

// Endpoints of the given edge ids.
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroFindEdges")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    HeteroGraphRef hg = args[0];
    dgl_type_t etype = args[1];
    IdArray eids = args[2];
    const auto& ret = hg->FindEdges(etype, eids);
    *rv = ConvertEdgeArrayToPackedFunc(ret);
  });

// In-edges of one vertex (_1) / of a batch of vertices (_2).
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroInEdges_1")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    HeteroGraphRef hg = args[0];
    dgl_type_t etype = args[1];
    dgl_id_t vid = args[2];
    const auto& ret = hg->InEdges(etype, vid);
    *rv = ConvertEdgeArrayToPackedFunc(ret);
  });

DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroInEdges_2")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    HeteroGraphRef hg = args[0];
    dgl_type_t etype = args[1];
    IdArray vids = args[2];
    const auto& ret = hg->InEdges(etype, vids);
    *rv = ConvertEdgeArrayToPackedFunc(ret);
  });

// Out-edges of one vertex (_1) / of a batch of vertices (_2).
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroOutEdges_1")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    HeteroGraphRef hg = args[0];
    dgl_type_t etype = args[1];
    dgl_id_t vid = args[2];
    const auto& ret = hg->OutEdges(etype, vid);
    *rv = ConvertEdgeArrayToPackedFunc(ret);
  });

DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroOutEdges_2")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    HeteroGraphRef hg = args[0];
    dgl_type_t etype = args[1];
    IdArray vids = args[2];
    const auto& ret = hg->OutEdges(etype, vids);
    *rv = ConvertEdgeArrayToPackedFunc(ret);
  });

// All edges of one relation; `order` selects the traversal order string.
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroEdges")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    HeteroGraphRef hg = args[0];
    dgl_type_t etype = args[1];
    std::string order = args[2];
    const auto& ret = hg->Edges(etype, order);
    *rv = ConvertEdgeArrayToPackedFunc(ret);
  });
// In-degree of one vertex / of a batch of vertices.
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroInDegree")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    HeteroGraphRef hg = args[0];
    dgl_type_t etype = args[1];
    dgl_id_t vid = args[2];
    *rv = static_cast<int64_t>(hg->InDegree(etype, vid));
  });

DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroInDegrees")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    HeteroGraphRef hg = args[0];
    dgl_type_t etype = args[1];
    IdArray vids = args[2];
    *rv = hg->InDegrees(etype, vids);
  });

// Out-degree of one vertex / of a batch of vertices.
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroOutDegree")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    HeteroGraphRef hg = args[0];
    dgl_type_t etype = args[1];
    dgl_id_t vid = args[2];
    *rv = static_cast<int64_t>(hg->OutDegree(etype, vid));
  });

DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroOutDegrees")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    HeteroGraphRef hg = args[0];
    dgl_type_t etype = args[1];
    IdArray vids = args[2];
    *rv = hg->OutDegrees(etype, vids);
  });

// Adjacency matrix of one relation, in the requested format string `fmt`,
// optionally transposed; returned as a vector of NDArrays.
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroGetAdj")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    HeteroGraphRef hg = args[0];
    dgl_type_t etype = args[1];
    bool transpose = args[2];
    std::string fmt = args[3];
    *rv = ConvertNDArrayVectorToPackedFunc(
        hg->GetAdj(etype, transpose, fmt));
  });
// Node-induced subgraph: `vids` supplies one id array per vertex type.
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroVertexSubgraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    HeteroGraphRef hg = args[0];
    List<Value> vids = args[1];
    std::vector<IdArray> vid_vec;
    vid_vec.reserve(vids.size());
    for (Value val : vids) {
      vid_vec.push_back(val->data);
    }
    std::shared_ptr<HeteroSubgraph> subg(
        new HeteroSubgraph(hg->VertexSubgraph(vid_vec)));
    *rv = HeteroSubgraphRef(subg);
  });

// Edge-induced subgraph: `eids` supplies one id array per edge type.
// When `preserve_nodes` is true, all original nodes are kept.
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroEdgeSubgraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    HeteroGraphRef hg = args[0];
    List<Value> eids = args[1];
    bool preserve_nodes = args[2];
    std::vector<IdArray> eid_vec;
    eid_vec.reserve(eids.size());
    for (Value val : eids) {
      eid_vec.push_back(val->data);
    }
    std::shared_ptr<HeteroSubgraph> subg(
        new HeteroSubgraph(hg->EdgeSubgraph(eid_vec, preserve_nodes)));
    *rv = HeteroSubgraphRef(subg);
  });
///////////////////////// HeteroSubgraph members /////////////////////////
// Accessor: the graph structure of a subgraph result.
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroSubgraphGetGraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    HeteroSubgraphRef subg = args[0];
    *rv = HeteroGraphRef(subg->graph);
  });

// Accessor: induced vertex ids, wrapped one Value per vertex type.
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroSubgraphGetInducedVertices")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    HeteroSubgraphRef subg = args[0];
    List<Value> induced_verts;
    for (IdArray arr : subg->induced_vertices) {
      induced_verts.push_back(Value(MakeValue(arr)));
    }
    *rv = induced_verts;
  });

// Accessor: induced edge ids, wrapped one Value per edge type.
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroSubgraphGetInducedEdges")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    HeteroSubgraphRef subg = args[0];
    List<Value> induced_edges;
    for (IdArray arr : subg->induced_edges) {
      induced_edges.push_back(Value(MakeValue(arr)));
    }
    *rv = induced_edges;
  });
///////////////////////// Global functions and algorithms /////////////////////////
// Converts the graph's id arrays to the requested bit-width.
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroAsNumBits")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    HeteroGraphRef hg = args[0];
    int bits = args[1];
    HeteroGraphPtr hg_new = UnitGraph::AsNumBits(hg.sptr(), bits);
    *rv = HeteroGraphRef(hg_new);
  });

// Copies the graph structure to the device identified by (type, id).
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCopyTo")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    HeteroGraphRef hg = args[0];
    int device_type = args[1];
    int device_id = args[2];
    DLContext ctx;
    ctx.device_type = static_cast<DLDeviceType>(device_type);
    ctx.device_id = device_id;
    HeteroGraphPtr hg_new = UnitGraph::CopyTo(hg.sptr(), ctx);
    *rv = HeteroGraphRef(hg_new);
  });

// Batches several heterographs sharing one metagraph into a single graph.
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroDisjointUnion")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    GraphRef meta_graph = args[0];
    List<HeteroGraphRef> component_graphs = args[1];
    std::vector<HeteroGraphPtr> component_ptrs;
    component_ptrs.reserve(component_graphs.size());
    for (const auto& component : component_graphs) {
      component_ptrs.push_back(component.sptr());
    }
    auto hgptr = DisjointUnionHeteroGraph(meta_graph.sptr(), component_ptrs);
    *rv = HeteroGraphRef(hgptr);
  });

// Splits a batched graph back into components given per-type size arrays.
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroDisjointPartitionBySizes")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    HeteroGraphRef hg = args[0];
    const IdArray vertex_sizes = args[1];
    const IdArray edge_sizes = args[2];
    const auto& ret = DisjointPartitionHeteroBySizes(
        hg->meta_graph(), hg.sptr(), vertex_sizes, edge_sizes);
    List<HeteroGraphRef> ret_list;
    for (HeteroGraphPtr hgptr : ret) {
      ret_list.push_back(HeteroGraphRef(hgptr));
    }
    *rv = ret_list;
  });

// Subgraph induced by in-edges of the given per-type seed node arrays.
DGL_REGISTER_GLOBAL("transform._CAPI_DGLInSubgraph")
.set_body([] (DGLArgs args, DGLRetValue *rv) {
    HeteroGraphRef hg = args[0];
    const auto& nodes = ListValueToVector<IdArray>(args[1]);
    std::shared_ptr<HeteroSubgraph> ret(new HeteroSubgraph);
    *ret = InEdgeGraph(hg.sptr(), nodes);
    *rv = HeteroGraphRef(ret);
  });

// Subgraph induced by out-edges of the given per-type seed node arrays.
DGL_REGISTER_GLOBAL("transform._CAPI_DGLOutSubgraph")
.set_body([] (DGLArgs args, DGLRetValue *rv) {
    HeteroGraphRef hg = args[0];
    const auto& nodes = ListValueToVector<IdArray>(args[1]);
    std::shared_ptr<HeteroSubgraph> ret(new HeteroSubgraph);
    *ret = OutEdgeGraph(hg.sptr(), nodes);
    *rv = HeteroGraphRef(ret);
  });
} // namespace dgl
/*!
* Copyright (c) 2020 by Contributors
* \file graph/pickle.cc
* \brief Functions for pickle and unpickle a graph
*/
#include <dgl/packed_func_ext.h>
#include <dgl/runtime/container.h>
#include "./heterograph.h"
#include "../c_api_common.h"
using namespace dgl::runtime;
namespace dgl {
// Snapshots a heterograph into picklable states: the metagraph plus one
// serialized sparse adjacency per edge type.
HeteroPickleStates HeteroPickle(HeteroGraphPtr graph) {
  HeteroPickleStates states;
  states.metagraph = graph->meta_graph();
  const auto num_etypes = graph->NumEdgeTypes();
  states.adjs.resize(num_etypes);
  for (dgl_type_t etype = 0; etype < num_etypes; ++etype) {
    states.adjs[etype] = std::make_shared<SparseMatrix>();
    // Serialize whichever sparse format the graph prefers for this relation.
    const SparseFormat fmt = graph->SelectFormat(etype, SparseFormat::ANY);
    if (fmt == SparseFormat::COO) {
      *states.adjs[etype] = graph->GetCOOMatrix(etype).ToSparseMatrix();
    } else if (fmt == SparseFormat::CSR || fmt == SparseFormat::CSC) {
      *states.adjs[etype] = graph->GetCSRMatrix(etype).ToSparseMatrix();
    } else {
      LOG(FATAL) << "Unsupported sparse format.";
    }
  }
  return states;
}
/*
 * Reconstructs a heterograph from its pickle states: one relation (unit)
 * graph is rebuilt per metagraph edge from the stored sparse matrix.
 */
HeteroGraphPtr HeteroUnpickle(const HeteroPickleStates& states) {
  const auto metagraph = states.metagraph;
  // One stored adjacency is required per metagraph edge (i.e. per etype).
  CHECK_EQ(states.adjs.size(), metagraph->NumEdges());
  std::vector<HeteroGraphPtr> relgraphs(metagraph->NumEdges());
  for (dgl_type_t etype = 0; etype < metagraph->NumEdges(); ++etype) {
    const auto& pair = metagraph->FindEdge(etype);
    const dgl_type_t srctype = pair.first;
    const dgl_type_t dsttype = pair.second;
    // A metagraph self-loop means src and dst share a single vertex type.
    const int64_t num_vtypes = (srctype == dsttype)? 1 : 2;
    const SparseFormat fmt = static_cast<SparseFormat>(states.adjs[etype]->format);
    switch (fmt) {
      case SparseFormat::COO:
        relgraphs[etype] = UnitGraph::CreateFromCOO(
            num_vtypes, aten::COOMatrix(*states.adjs[etype]));
        break;
      case SparseFormat::CSR:
        relgraphs[etype] = UnitGraph::CreateFromCSR(
            num_vtypes, aten::CSRMatrix(*states.adjs[etype]));
        break;
      case SparseFormat::CSC:
      default:
        // NOTE(review): CSC deliberately falls through to the fatal branch —
        // presumably HeteroPickle never emits CSC-tagged states; confirm.
        LOG(FATAL) << "Unsupported sparse format.";
    }
  }
  return CreateHeteroGraph(metagraph, relgraphs);
}
// Accessor: the metagraph stored in a pickle-states object.
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroPickleStatesGetMetagraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    HeteroPickleStatesRef st = args[0];
    *rv = GraphRef(st->metagraph);
  });

// Accessor: the per-etype sparse adjacencies stored in a pickle-states object.
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroPickleStatesGetAdjs")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    HeteroPickleStatesRef st = args[0];
    std::vector<SparseMatrixRef> refs(st->adjs.begin(), st->adjs.end());
    *rv = List<SparseMatrixRef>(refs);
  });

// Builds a pickle-states object from a metagraph and adjacency list
// (used on the unpickling side before reconstructing the graph).
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLCreateHeteroPickleStates")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    GraphRef metagraph = args[0];
    List<SparseMatrixRef> adjs = args[1];
    std::shared_ptr<HeteroPickleStates> st( new HeteroPickleStates );
    st->metagraph = metagraph.sptr();
    st->adjs.reserve(adjs.size());
    for (const auto& ref : adjs)
      st->adjs.push_back(ref.sptr());
    *rv = HeteroPickleStatesRef(st);
  });

// Serializes a heterograph into pickle states.
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroPickle")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    HeteroGraphRef ref = args[0];
    std::shared_ptr<HeteroPickleStates> st( new HeteroPickleStates );
    *st = HeteroPickle(ref.sptr());
    *rv = HeteroPickleStatesRef(st);
  });

// Reconstructs a heterograph from pickle states.
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroUnpickle")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    HeteroPickleStatesRef ref = args[0];
    HeteroGraphPtr graph = HeteroUnpickle(*ref.sptr());
    *rv = HeteroGraphRef(graph);
  });
} // namespace dgl
...@@ -931,7 +931,7 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_NeighborSampling") ...@@ -931,7 +931,7 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_NeighborSampling")
{ {
const FloatType *prob; const FloatType *prob;
if (probability->ndim == 1 && probability->shape[0] == 0) { if (aten::IsNullArray(probability)) {
prob = nullptr; prob = nullptr;
} else { } else {
CHECK(probability->shape[0] == gptr->NumEdges()) CHECK(probability->shape[0] == gptr->NumEdges())
...@@ -1237,7 +1237,7 @@ NegSubgraph EdgeSamplerObject::genNegEdgeSubgraph(const Subgraph &pos_subg, ...@@ -1237,7 +1237,7 @@ NegSubgraph EdgeSamplerObject::genNegEdgeSubgraph(const Subgraph &pos_subg,
} }
// TODO(zhengda) we should provide an array of 1s if exclude_positive // TODO(zhengda) we should provide an array of 1s if exclude_positive
if (check_false_neg) { if (check_false_neg) {
if (relations_->shape[0] == 0) { if (aten::IsNullArray(relations_)) {
neg_subg.exist = CheckExistence(gptr_, neg_src, neg_dst, induced_neg_vid); neg_subg.exist = CheckExistence(gptr_, neg_src, neg_dst, induced_neg_vid);
} else { } else {
neg_subg.exist = CheckExistence(gptr_, relations_, neg_src, neg_dst, neg_subg.exist = CheckExistence(gptr_, relations_, neg_src, neg_dst,
...@@ -1386,7 +1386,7 @@ NegSubgraph EdgeSamplerObject::genChunkedNegEdgeSubgraph(const Subgraph &pos_sub ...@@ -1386,7 +1386,7 @@ NegSubgraph EdgeSamplerObject::genChunkedNegEdgeSubgraph(const Subgraph &pos_sub
neg_subg.tail_nid = aten::VecToIdArray(Global2Local(global_neg_vids, neg_map)); neg_subg.tail_nid = aten::VecToIdArray(Global2Local(global_neg_vids, neg_map));
} }
if (check_false_neg) { if (check_false_neg) {
if (relations_->shape[0] == 0) { if (aten::IsNullArray(relations_)) {
neg_subg.exist = CheckExistence(gptr_, neg_src, neg_dst, induced_neg_vid); neg_subg.exist = CheckExistence(gptr_, neg_src, neg_dst, induced_neg_vid);
} else { } else {
neg_subg.exist = CheckExistence(gptr_, relations_, neg_src, neg_dst, neg_subg.exist = CheckExistence(gptr_, relations_, neg_src, neg_dst,
......
...@@ -41,14 +41,14 @@ HeteroSubgraph SampleNeighbors( ...@@ -41,14 +41,14 @@ HeteroSubgraph SampleNeighbors(
const dgl_type_t dst_vtype = pair.second; const dgl_type_t dst_vtype = pair.second;
const IdArray nodes_ntype = nodes[(dir == EdgeDir::kOut)? src_vtype : dst_vtype]; const IdArray nodes_ntype = nodes[(dir == EdgeDir::kOut)? src_vtype : dst_vtype];
const int64_t num_nodes = nodes_ntype->shape[0]; const int64_t num_nodes = nodes_ntype->shape[0];
if (num_nodes == 0) { if (num_nodes == 0 || fanouts[etype] == 0) {
// No node provided in the type, create a placeholder relation graph // Nothing to sample for this etype, create a placeholder relation graph
subrels[etype] = UnitGraph::Empty( subrels[etype] = UnitGraph::Empty(
hg->GetRelationGraph(etype)->NumVertexTypes(), hg->GetRelationGraph(etype)->NumVertexTypes(),
hg->NumVertices(src_vtype), hg->NumVertices(src_vtype),
hg->NumVertices(dst_vtype), hg->NumVertices(dst_vtype),
hg->DataType(), hg->Context()); hg->DataType(), hg->Context());
induced_edges[etype] = IdArray::Empty({0}, hg->DataType(), hg->Context()); induced_edges[etype] = aten::NullArray();
} else { } else {
// sample from one relation graph // sample from one relation graph
auto req_fmt = (dir == EdgeDir::kOut)? SparseFormat::CSR : SparseFormat::CSC; auto req_fmt = (dir == EdgeDir::kOut)? SparseFormat::CSR : SparseFormat::CSC;
...@@ -81,11 +81,7 @@ HeteroSubgraph SampleNeighbors( ...@@ -81,11 +81,7 @@ HeteroSubgraph SampleNeighbors(
} }
subrels[etype] = UnitGraph::CreateFromCOO( subrels[etype] = UnitGraph::CreateFromCOO(
hg->GetRelationGraph(etype)->NumVertexTypes(), sampled_coo); hg->GetRelationGraph(etype)->NumVertexTypes(), sampled_coo);
if (sampled_coo.data.defined()) {
induced_edges[etype] = sampled_coo.data; induced_edges[etype] = sampled_coo.data;
} else {
induced_edges[etype] = IdArray::Empty({0}, hg->DataType(), hg->Context());
}
} }
} }
...@@ -119,14 +115,14 @@ HeteroSubgraph SampleNeighborsTopk( ...@@ -119,14 +115,14 @@ HeteroSubgraph SampleNeighborsTopk(
const dgl_type_t dst_vtype = pair.second; const dgl_type_t dst_vtype = pair.second;
const IdArray nodes_ntype = nodes[(dir == EdgeDir::kOut)? src_vtype : dst_vtype]; const IdArray nodes_ntype = nodes[(dir == EdgeDir::kOut)? src_vtype : dst_vtype];
const int64_t num_nodes = nodes_ntype->shape[0]; const int64_t num_nodes = nodes_ntype->shape[0];
if (num_nodes == 0) { if (num_nodes == 0 || k[etype] == 0) {
// No node provided in the type, create a placeholder relation graph // Nothing to sample for this etype, create a placeholder relation graph
subrels[etype] = UnitGraph::Empty( subrels[etype] = UnitGraph::Empty(
hg->GetRelationGraph(etype)->NumVertexTypes(), hg->GetRelationGraph(etype)->NumVertexTypes(),
hg->NumVertices(src_vtype), hg->NumVertices(src_vtype),
hg->NumVertices(dst_vtype), hg->NumVertices(dst_vtype),
hg->DataType(), hg->Context()); hg->DataType(), hg->Context());
induced_edges[etype] = IdArray::Empty({0}, hg->DataType(), hg->Context()); induced_edges[etype] = aten::NullArray();
} else { } else {
// sample from one relation graph // sample from one relation graph
auto req_fmt = (dir == EdgeDir::kOut)? SparseFormat::CSR : SparseFormat::CSC; auto req_fmt = (dir == EdgeDir::kOut)? SparseFormat::CSR : SparseFormat::CSC;
...@@ -159,11 +155,7 @@ HeteroSubgraph SampleNeighborsTopk( ...@@ -159,11 +155,7 @@ HeteroSubgraph SampleNeighborsTopk(
} }
subrels[etype] = UnitGraph::CreateFromCOO( subrels[etype] = UnitGraph::CreateFromCOO(
hg->GetRelationGraph(etype)->NumVertexTypes(), sampled_coo); hg->GetRelationGraph(etype)->NumVertexTypes(), sampled_coo);
if (sampled_coo.data.defined()) {
induced_edges[etype] = sampled_coo.data; induced_edges[etype] = sampled_coo.data;
} else {
induced_edges[etype] = IdArray::Empty({0}, hg->DataType(), hg->Context());
}
} }
} }
......
...@@ -76,7 +76,7 @@ std::pair<dgl_id_t, bool> MetapathRandomWalkStep( ...@@ -76,7 +76,7 @@ std::pair<dgl_id_t, bool> MetapathRandomWalkStep(
FloatArray prob_etype = prob[etype]; FloatArray prob_etype = prob[etype];
IdxType idx; IdxType idx;
if (prob_etype->shape[0] == 0) { if (IsNullArray(prob_etype)) {
// empty probability array; assume uniform // empty probability array; assume uniform
idx = RandomEngine::ThreadLocal()->RandInt(size); idx = RandomEngine::ThreadLocal()->RandInt(size);
} else { } else {
......
/*!
* Copyright (c) 2020 by Contributors
* \file graph/subgraph.cc
* \brief Functions for extracting subgraphs.
*/
#include "./heterograph.h"
using namespace dgl::runtime;
namespace dgl {
// Extracts the subgraph induced by the in-edges of the given seed nodes.
// `vids` holds one seed-id array per vertex type; a null array for a type
// means no seeds of that type.
HeteroSubgraph InEdgeGraph(const HeteroGraphPtr graph, const std::vector<IdArray>& vids) {
  CHECK_EQ(vids.size(), graph->NumVertexTypes())
    << "Invalid input: the input list size must be the same as the number of vertex types.";
  const auto num_etypes = graph->NumEdgeTypes();
  std::vector<HeteroGraphPtr> subrels(num_etypes);
  std::vector<IdArray> induced_edges(num_etypes);
  for (dgl_type_t etype = 0; etype < num_etypes; ++etype) {
    const auto endpoints = graph->meta_graph()->FindEdge(etype);
    const dgl_type_t src_vtype = endpoints.first;
    const dgl_type_t dst_vtype = endpoints.second;
    auto relgraph = graph->GetRelationGraph(etype);
    if (!aten::IsNullArray(vids[dst_vtype])) {
      // Gather every edge entering the seed nodes of the destination type.
      const auto& earr = graph->InEdges(etype, {vids[dst_vtype]});
      subrels[etype] = UnitGraph::CreateFromCOO(
          relgraph->NumVertexTypes(),
          graph->NumVertices(src_vtype),
          graph->NumVertices(dst_vtype),
          earr.src,
          earr.dst);
      induced_edges[etype] = earr.id;
    } else {
      // No seeds of this destination type: emit an edgeless placeholder.
      subrels[etype] = UnitGraph::Empty(
          relgraph->NumVertexTypes(),
          graph->NumVertices(src_vtype),
          graph->NumVertices(dst_vtype),
          graph->DataType(), graph->Context());
      induced_edges[etype] = IdArray::Empty({0}, graph->DataType(), graph->Context());
    }
  }
  HeteroSubgraph ret;
  ret.graph = CreateHeteroGraph(graph->meta_graph(), subrels);
  ret.induced_edges = std::move(induced_edges);
  return ret;
}
// Extracts the subgraph induced by the out-edges of the given seed nodes.
// `vids` holds one seed-id array per vertex type; a null array for a type
// means no seeds of that type.
HeteroSubgraph OutEdgeGraph(const HeteroGraphPtr graph, const std::vector<IdArray>& vids) {
  CHECK_EQ(vids.size(), graph->NumVertexTypes())
    << "Invalid input: the input list size must be the same as the number of vertex types.";
  const auto num_etypes = graph->NumEdgeTypes();
  std::vector<HeteroGraphPtr> subrels(num_etypes);
  std::vector<IdArray> induced_edges(num_etypes);
  for (dgl_type_t etype = 0; etype < num_etypes; ++etype) {
    const auto endpoints = graph->meta_graph()->FindEdge(etype);
    const dgl_type_t src_vtype = endpoints.first;
    const dgl_type_t dst_vtype = endpoints.second;
    auto relgraph = graph->GetRelationGraph(etype);
    if (!aten::IsNullArray(vids[src_vtype])) {
      // Gather every edge leaving the seed nodes of the source type.
      const auto& earr = graph->OutEdges(etype, {vids[src_vtype]});
      subrels[etype] = UnitGraph::CreateFromCOO(
          relgraph->NumVertexTypes(),
          graph->NumVertices(src_vtype),
          graph->NumVertices(dst_vtype),
          earr.src,
          earr.dst);
      induced_edges[etype] = earr.id;
    } else {
      // No seeds of this source type: emit an edgeless placeholder.
      subrels[etype] = UnitGraph::Empty(
          relgraph->NumVertexTypes(),
          graph->NumVertices(src_vtype),
          graph->NumVertices(dst_vtype),
          graph->DataType(), graph->Context());
      induced_edges[etype] = IdArray::Empty({0}, graph->DataType(), graph->Context());
    }
  }
  HeteroSubgraph ret;
  ret.graph = CreateHeteroGraph(graph->meta_graph(), subrels);
  ret.induced_edges = std::move(induced_edges);
  return ret;
}
} // namespace dgl
/*!
* Copyright (c) 2020 by Contributors
* \file graph/transform/union_partition.cc
* \brief Functions for partition, union multiple graphs.
*/
#include "../heterograph.h"
using namespace dgl::runtime;
namespace dgl {
/*
 * Concatenates multiple heterographs that share one metagraph into a single
 * batched graph.  For each edge type, node ids of each component are shifted
 * by the total node count of all preceding components.
 */
HeteroGraphPtr DisjointUnionHeteroGraph(
    GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& component_graphs) {
  CHECK_GT(component_graphs.size(), 0) << "Input graph list is empty";
  std::vector<HeteroGraphPtr> rel_graphs(meta_graph->NumEdges());
  // Loop over all canonical etypes
  for (dgl_type_t etype = 0; etype < meta_graph->NumEdges(); ++etype) {
    auto pair = meta_graph->FindEdge(etype);
    const dgl_type_t src_vtype = pair.first;
    const dgl_type_t dst_vtype = pair.second;
    // Running node-count offsets; they double as the final node totals
    // when the per-etype loop below finishes.
    dgl_id_t src_offset = 0, dst_offset = 0;
    std::vector<dgl_id_t> result_src, result_dst;
    // Loop over all graphs
    for (size_t i = 0; i < component_graphs.size(); ++i) {
      const auto& cg = component_graphs[i];
      EdgeArray edges = cg->Edges(etype);
      size_t num_edges = cg->NumEdges(etype);
      // NOTE(review): raw casts assume the components use dgl_id_t-width
      // (64-bit) id arrays — confirm callers guarantee this.
      const dgl_id_t* edges_src_data = static_cast<const dgl_id_t*>(edges.src->data);
      const dgl_id_t* edges_dst_data = static_cast<const dgl_id_t*>(edges.dst->data);
      // Loop over all edges
      for (size_t j = 0; j < num_edges; ++j) {
        // TODO(mufei): Should use array operations to implement this.
        result_src.push_back(edges_src_data[j] + src_offset);
        result_dst.push_back(edges_dst_data[j] + dst_offset);
      }
      // Update offsets
      src_offset += cg->NumVertices(src_vtype);
      dst_offset += cg->NumVertices(dst_vtype);
    }
    HeteroGraphPtr rgptr = UnitGraph::CreateFromCOO(
        (src_vtype == dst_vtype)? 1 : 2,
        src_offset,
        dst_offset,
        aten::VecToIdArray(result_src),
        aten::VecToIdArray(result_dst));
    rel_graphs[etype] = rgptr;
  }
  return HeteroGraphPtr(new HeteroGraph(meta_graph, rel_graphs));
}
std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes(
    GraphPtr meta_graph, HeteroGraphPtr batched_graph, IdArray vertex_sizes, IdArray edge_sizes) {
  // Inverse of DisjointUnionHeteroGraph: split `batched_graph` back into
  // `batch_size` graphs. `vertex_sizes` is flattened as
  // [vtype0_g0, ..., vtype0_g{B-1}, vtype1_g0, ...]; `edge_sizes` likewise
  // per edge type.
  // Sanity check for vertex sizes
  const uint64_t len_vertex_sizes = vertex_sizes->shape[0];
  const uint64_t* vertex_sizes_data = static_cast<uint64_t*>(vertex_sizes->data);
  const uint64_t num_vertex_types = meta_graph->NumVertices();
  const uint64_t batch_size = len_vertex_sizes / num_vertex_types;
  // Map vertex type to the corresponding node cum sum
  std::vector<std::vector<uint64_t>> vertex_cumsum;
  vertex_cumsum.resize(num_vertex_types);
  // Loop over all vertex types
  for (uint64_t vtype = 0; vtype < num_vertex_types; ++vtype) {
    vertex_cumsum[vtype].reserve(batch_size + 1);  // one prefix per graph + leading 0
    vertex_cumsum[vtype].push_back(0);
    for (uint64_t g = 0; g < batch_size; ++g) {
      // We've flattened the number of vertices in the batch for all types
      vertex_cumsum[vtype].push_back(
        vertex_cumsum[vtype][g] + vertex_sizes_data[vtype * batch_size + g]);
    }
    CHECK_EQ(vertex_cumsum[vtype][batch_size], batched_graph->NumVertices(vtype))
      << "Sum of the given sizes must equal to the number of nodes for type " << vtype;
  }
  // Sanity check for edge sizes
  const uint64_t* edge_sizes_data = static_cast<uint64_t*>(edge_sizes->data);
  const uint64_t num_edge_types = meta_graph->NumEdges();
  // Map edge type to the corresponding edge cum sum
  std::vector<std::vector<uint64_t>> edge_cumsum;
  edge_cumsum.resize(num_edge_types);
  // Loop over all edge types
  for (uint64_t etype = 0; etype < num_edge_types; ++etype) {
    edge_cumsum[etype].reserve(batch_size + 1);
    edge_cumsum[etype].push_back(0);
    for (uint64_t g = 0; g < batch_size; ++g) {
      // We've flattened the number of edges in the batch for all types
      edge_cumsum[etype].push_back(
        edge_cumsum[etype][g] + edge_sizes_data[etype * batch_size + g]);
    }
    CHECK_EQ(edge_cumsum[etype][batch_size], batched_graph->NumEdges(etype))
      << "Sum of the given sizes must equal to the number of edges for type " << etype;
  }
  // Construct relation graphs for unbatched graphs
  std::vector<std::vector<HeteroGraphPtr>> rel_graphs;
  rel_graphs.resize(batch_size);
  for (uint64_t g = 0; g < batch_size; ++g)
    rel_graphs[g].reserve(num_edge_types);  // each graph gets one relation per etype
  // Loop over all edge types
  for (uint64_t etype = 0; etype < num_edge_types; ++etype) {
    auto pair = meta_graph->FindEdge(etype);
    const dgl_type_t src_vtype = pair.first;
    const dgl_type_t dst_vtype = pair.second;
    EdgeArray edges = batched_graph->Edges(etype);
    const dgl_id_t* edges_src_data = static_cast<const dgl_id_t*>(edges.src->data);
    const dgl_id_t* edges_dst_data = static_cast<const dgl_id_t*>(edges.dst->data);
    // Loop over all graphs to be unbatched
    for (uint64_t g = 0; g < batch_size; ++g) {
      std::vector<dgl_id_t> result_src, result_dst;
      // Reserve the exact chunk size to avoid reallocation in the loop below.
      const uint64_t chunk = edge_cumsum[etype][g + 1] - edge_cumsum[etype][g];
      result_src.reserve(chunk);
      result_dst.reserve(chunk);
      // Loop over the chunk of edges for the specified graph and edge type;
      // subtracting the vertex prefix sums maps IDs back to local ID space.
      for (uint64_t e = edge_cumsum[etype][g]; e < edge_cumsum[etype][g + 1]; ++e) {
        // TODO(mufei): Should use array operations to implement this.
        result_src.push_back(edges_src_data[e] - vertex_cumsum[src_vtype][g]);
        result_dst.push_back(edges_dst_data[e] - vertex_cumsum[dst_vtype][g]);
      }
      HeteroGraphPtr rgptr = UnitGraph::CreateFromCOO(
          (src_vtype == dst_vtype)? 1 : 2,
          vertex_sizes_data[src_vtype * batch_size + g],
          vertex_sizes_data[dst_vtype * batch_size + g],
          aten::VecToIdArray(result_src),
          aten::VecToIdArray(result_dst));
      rel_graphs[g].push_back(rgptr);
    }
  }
  std::vector<HeteroGraphPtr> rst;
  rst.reserve(batch_size);
  for (uint64_t g = 0; g < batch_size; ++g) {
    rst.push_back(HeteroGraphPtr(new HeteroGraph(meta_graph, rel_graphs[g])));
  }
  return rst;
}
} // namespace dgl
...@@ -83,7 +83,7 @@ class UnitGraph::COO : public BaseHeteroGraph { ...@@ -83,7 +83,7 @@ class UnitGraph::COO : public BaseHeteroGraph {
: BaseHeteroGraph(metagraph), adj_(coo) { : BaseHeteroGraph(metagraph), adj_(coo) {
// Data index should not be inherited. Edges in COO format are always // Data index should not be inherited. Edges in COO format are always
// assigned ids from 0 to num_edges - 1. // assigned ids from 0 to num_edges - 1.
adj_.data = IdArray(); adj_.data = aten::NullArray();
} }
inline dgl_type_t SrcType() const { inline dgl_type_t SrcType() const {
......
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
namespace dgl { namespace dgl {
class HeteroGraph;
class UnitGraph; class UnitGraph;
typedef std::shared_ptr<UnitGraph> UnitGraphPtr; typedef std::shared_ptr<UnitGraph> UnitGraphPtr;
...@@ -212,6 +213,7 @@ class UnitGraph : public BaseHeteroGraph { ...@@ -212,6 +213,7 @@ class UnitGraph : public BaseHeteroGraph {
private: private:
friend class Serializer; friend class Serializer;
friend class HeteroGraph;
/*! /*!
* \brief constructor * \brief constructor
...@@ -227,8 +229,10 @@ class UnitGraph : public BaseHeteroGraph { ...@@ -227,8 +229,10 @@ class UnitGraph : public BaseHeteroGraph {
HeteroGraphPtr GetAny() const; HeteroGraphPtr GetAny() const;
/*! /*!
* \return Return the given format. Perform format conversion if requested format does * \brief Return the graph in the given format. Perform format conversion if the
* not exist. * requested format does not exist.
*
* \return A graph in the requested format.
*/ */
HeteroGraphPtr GetFormat(SparseFormat format) const; HeteroGraphPtr GetFormat(SparseFormat format) const;
......
...@@ -181,7 +181,7 @@ inline void CheckCtx( ...@@ -181,7 +181,7 @@ inline void CheckCtx(
const std::vector<NDArray>& arrays, const std::vector<NDArray>& arrays,
const std::vector<std::string>& names) { const std::vector<std::string>& names) {
for (size_t i = 0; i < arrays.size(); ++i) { for (size_t i = 0; i < arrays.size(); ++i) {
if (utils::IsNoneArray(arrays[i])) if (aten::IsNullArray(arrays[i]))
continue; continue;
CHECK_EQ(ctx, arrays[i]->ctx) CHECK_EQ(ctx, arrays[i]->ctx)
<< "Expected device context " << ctx << ". But got " << "Expected device context " << ctx << ". But got "
...@@ -195,7 +195,7 @@ inline void CheckIdArray( ...@@ -195,7 +195,7 @@ inline void CheckIdArray(
const std::vector<NDArray>& arrays, const std::vector<NDArray>& arrays,
const std::vector<std::string>& names) { const std::vector<std::string>& names) {
for (size_t i = 0; i < arrays.size(); ++i) { for (size_t i = 0; i < arrays.size(); ++i) {
if (utils::IsNoneArray(arrays[i])) if (aten::IsNullArray(arrays[i]))
continue; continue;
CHECK(arrays[i]->dtype.code == kDLInt); CHECK(arrays[i]->dtype.code == kDLInt);
CHECK_EQ(arrays[i]->ndim, 1); CHECK_EQ(arrays[i]->ndim, 1);
...@@ -415,14 +415,14 @@ void BackwardLhsBinaryOpReduce( ...@@ -415,14 +415,14 @@ void BackwardLhsBinaryOpReduce(
lhs, rhs, lhs, rhs,
lhs_mapping, rhs_mapping, out_mapping, lhs_mapping, rhs_mapping, out_mapping,
lhs_data, rhs_data, out_data, grad_out_data, lhs_data, rhs_data, out_data, grad_out_data,
grad_lhs_data, utils::NoneArray()); grad_lhs_data, aten::NullArray());
} else { } else {
DGL_XPU_SWITCH(ctx.device_type, BackwardBinaryReduceImpl, DGL_XPU_SWITCH(ctx.device_type, BackwardBinaryReduceImpl,
reducer, op, graph, reducer, op, graph,
lhs, rhs, lhs, rhs,
lhs_mapping, rhs_mapping, out_mapping, lhs_mapping, rhs_mapping, out_mapping,
lhs_data, rhs_data, out_data, grad_out_data, lhs_data, rhs_data, out_data, grad_out_data,
grad_lhs_data, utils::NoneArray()); grad_lhs_data, aten::NullArray());
} }
} }
} }
...@@ -491,14 +491,14 @@ void BackwardRhsBinaryOpReduce( ...@@ -491,14 +491,14 @@ void BackwardRhsBinaryOpReduce(
lhs, rhs, lhs, rhs,
lhs_mapping, rhs_mapping, out_mapping, lhs_mapping, rhs_mapping, out_mapping,
lhs_data, rhs_data, out_data, grad_out_data, lhs_data, rhs_data, out_data, grad_out_data,
utils::NoneArray(), grad_rhs_data); aten::NullArray(), grad_rhs_data);
} else { } else {
DGL_XPU_SWITCH(ctx.device_type, BackwardBinaryReduceImpl, DGL_XPU_SWITCH(ctx.device_type, BackwardBinaryReduceImpl,
reducer, op, graph, reducer, op, graph,
lhs, rhs, lhs, rhs,
lhs_mapping, rhs_mapping, out_mapping, lhs_mapping, rhs_mapping, out_mapping,
lhs_data, rhs_data, out_data, grad_out_data, lhs_data, rhs_data, out_data, grad_out_data,
utils::NoneArray(), grad_rhs_data); aten::NullArray(), grad_rhs_data);
} }
} }
} }
...@@ -548,8 +548,8 @@ void CopyReduce( ...@@ -548,8 +548,8 @@ void CopyReduce(
DGL_XPU_SWITCH(ctx.device_type, BinaryReduceImpl, DGL_XPU_SWITCH(ctx.device_type, BinaryReduceImpl,
reducer, binary_op::kUseLhs, graph, reducer, binary_op::kUseLhs, graph,
target, binary_op::kNone, target, binary_op::kNone,
in_data, utils::NoneArray(), out_data, in_data, aten::NullArray(), out_data,
in_mapping, utils::NoneArray(), out_mapping); in_mapping, aten::NullArray(), out_mapping);
} }
DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelCopyReduce") DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelCopyReduce")
...@@ -588,16 +588,16 @@ void BackwardCopyReduce( ...@@ -588,16 +588,16 @@ void BackwardCopyReduce(
CheckIdArray(graph.NumBits(), CheckIdArray(graph.NumBits(),
{in_mapping, out_mapping}, {in_mapping, out_mapping},
{"in_mapping", "out_mapping"}); {"in_mapping", "out_mapping"});
if (!utils::IsNoneArray(out_mapping)) { if (!aten::IsNullArray(out_mapping)) {
CHECK_EQ(ctx, out_mapping->ctx) << "Expected device context " << ctx CHECK_EQ(ctx, out_mapping->ctx) << "Expected device context " << ctx
<< ". But got " << out_mapping->ctx << " for rhs_data."; << ". But got " << out_mapping->ctx << " for rhs_data.";
} }
DGL_XPU_SWITCH(ctx.device_type, BackwardBinaryReduceImpl, DGL_XPU_SWITCH(ctx.device_type, BackwardBinaryReduceImpl,
reducer, binary_op::kUseLhs, graph, reducer, binary_op::kUseLhs, graph,
target, binary_op::kNone, target, binary_op::kNone,
in_mapping, utils::NoneArray(), out_mapping, in_mapping, aten::NullArray(), out_mapping,
in_data, utils::NoneArray(), out_data, grad_out_data, in_data, aten::NullArray(), out_data, grad_out_data,
grad_in_data, utils::NoneArray()); grad_in_data, aten::NullArray());
} }
DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelBackwardCopyReduce") DGL_REGISTER_GLOBAL("kernel._CAPI_DGLKernelBackwardCopyReduce")
......
...@@ -39,13 +39,13 @@ GData<Idx, DType> AllocGData(const std::string& op, ...@@ -39,13 +39,13 @@ GData<Idx, DType> AllocGData(const std::string& op,
gdata.lhs_data = static_cast<DType*>(lhs_data->data); gdata.lhs_data = static_cast<DType*>(lhs_data->data);
gdata.rhs_data = static_cast<DType*>(rhs_data->data); gdata.rhs_data = static_cast<DType*>(rhs_data->data);
gdata.out_data = static_cast<DType*>(out_data->data); gdata.out_data = static_cast<DType*>(out_data->data);
if (!utils::IsNoneArray(lhs_mapping)) { if (!aten::IsNullArray(lhs_mapping)) {
gdata.lhs_mapping = static_cast<Idx*>(lhs_mapping->data); gdata.lhs_mapping = static_cast<Idx*>(lhs_mapping->data);
} }
if (!utils::IsNoneArray(rhs_mapping)) { if (!aten::IsNullArray(rhs_mapping)) {
gdata.rhs_mapping = static_cast<Idx*>(rhs_mapping->data); gdata.rhs_mapping = static_cast<Idx*>(rhs_mapping->data);
} }
if (!utils::IsNoneArray(out_mapping)) { if (!aten::IsNullArray(out_mapping)) {
gdata.out_mapping = static_cast<Idx*>(out_mapping->data); gdata.out_mapping = static_cast<Idx*>(out_mapping->data);
} }
...@@ -130,25 +130,25 @@ BackwardGData<Idx, DType> AllocBackwardGData( ...@@ -130,25 +130,25 @@ BackwardGData<Idx, DType> AllocBackwardGData(
gdata.rhs_data = static_cast<DType*>(rhs_data->data); gdata.rhs_data = static_cast<DType*>(rhs_data->data);
gdata.out_data = static_cast<DType*>(out_data->data); gdata.out_data = static_cast<DType*>(out_data->data);
gdata.grad_out_data = static_cast<DType*>(grad_out_data->data); gdata.grad_out_data = static_cast<DType*>(grad_out_data->data);
if (!utils::IsNoneArray(grad_lhs_data)) { if (!aten::IsNullArray(grad_lhs_data)) {
gdata.grad_lhs_data = static_cast<DType*>(grad_lhs_data->data); gdata.grad_lhs_data = static_cast<DType*>(grad_lhs_data->data);
// fill out data with zero values // fill out data with zero values
utils::Fill<XPU>(ctx, gdata.grad_lhs_data, utils::NElements(grad_lhs_data), utils::Fill<XPU>(ctx, gdata.grad_lhs_data, utils::NElements(grad_lhs_data),
static_cast<DType>(0)); static_cast<DType>(0));
} }
if (!utils::IsNoneArray(grad_rhs_data)) { if (!aten::IsNullArray(grad_rhs_data)) {
gdata.grad_rhs_data = static_cast<DType*>(grad_rhs_data->data); gdata.grad_rhs_data = static_cast<DType*>(grad_rhs_data->data);
// fill out data with zero values // fill out data with zero values
utils::Fill<XPU>(ctx, gdata.grad_rhs_data, utils::NElements(grad_rhs_data), utils::Fill<XPU>(ctx, gdata.grad_rhs_data, utils::NElements(grad_rhs_data),
static_cast<DType>(0)); static_cast<DType>(0));
} }
if (!utils::IsNoneArray(lhs_mapping)) { if (!aten::IsNullArray(lhs_mapping)) {
gdata.lhs_mapping = static_cast<Idx*>(lhs_mapping->data); gdata.lhs_mapping = static_cast<Idx*>(lhs_mapping->data);
} }
if (!utils::IsNoneArray(rhs_mapping)) { if (!aten::IsNullArray(rhs_mapping)) {
gdata.rhs_mapping = static_cast<Idx*>(rhs_mapping->data); gdata.rhs_mapping = static_cast<Idx*>(rhs_mapping->data);
} }
if (!utils::IsNoneArray(out_mapping)) { if (!aten::IsNullArray(out_mapping)) {
gdata.out_mapping = static_cast<Idx*>(out_mapping->data); gdata.out_mapping = static_cast<Idx*>(out_mapping->data);
} }
...@@ -194,8 +194,8 @@ void BackwardBinaryReduceImpl( ...@@ -194,8 +194,8 @@ void BackwardBinaryReduceImpl(
#endif #endif
const DLDataType& dtype = out_data->dtype; const DLDataType& dtype = out_data->dtype;
const bool req_lhs = !utils::IsNoneArray(grad_lhs_data); const bool req_lhs = !aten::IsNullArray(grad_lhs_data);
const bool req_rhs = !utils::IsNoneArray(grad_rhs_data); const bool req_rhs = !aten::IsNullArray(grad_rhs_data);
const auto bits = graph.NumBits(); const auto bits = graph.NumBits();
if (reducer == binary_op::kReduceMean) { if (reducer == binary_op::kReduceMean) {
...@@ -247,13 +247,13 @@ BcastGData<NDim, Idx, DType> AllocBcastGData( ...@@ -247,13 +247,13 @@ BcastGData<NDim, Idx, DType> AllocBcastGData(
gdata.lhs_data = static_cast<DType*>(lhs_data->data); gdata.lhs_data = static_cast<DType*>(lhs_data->data);
gdata.rhs_data = static_cast<DType*>(rhs_data->data); gdata.rhs_data = static_cast<DType*>(rhs_data->data);
gdata.out_data = static_cast<DType*>(out_data->data); gdata.out_data = static_cast<DType*>(out_data->data);
if (!utils::IsNoneArray(lhs_mapping)) { if (!aten::IsNullArray(lhs_mapping)) {
gdata.lhs_mapping = static_cast<Idx*>(lhs_mapping->data); gdata.lhs_mapping = static_cast<Idx*>(lhs_mapping->data);
} }
if (!utils::IsNoneArray(rhs_mapping)) { if (!aten::IsNullArray(rhs_mapping)) {
gdata.rhs_mapping = static_cast<Idx*>(rhs_mapping->data); gdata.rhs_mapping = static_cast<Idx*>(rhs_mapping->data);
} }
if (!utils::IsNoneArray(out_mapping)) { if (!aten::IsNullArray(out_mapping)) {
gdata.out_mapping = static_cast<Idx*>(out_mapping->data); gdata.out_mapping = static_cast<Idx*>(out_mapping->data);
} }
gdata.data_len = info.data_len; gdata.data_len = info.data_len;
...@@ -344,13 +344,13 @@ BackwardBcastGData<NDim, Idx, DType> AllocBackwardBcastGData( ...@@ -344,13 +344,13 @@ BackwardBcastGData<NDim, Idx, DType> AllocBackwardBcastGData(
std::copy(info.out_shape.begin(), info.out_shape.end(), gdata.out_shape); std::copy(info.out_shape.begin(), info.out_shape.end(), gdata.out_shape);
std::copy(info.out_stride.begin(), info.out_stride.end(), gdata.out_stride); std::copy(info.out_stride.begin(), info.out_stride.end(), gdata.out_stride);
// mappings // mappings
if (!utils::IsNoneArray(lhs_mapping)) { if (!aten::IsNullArray(lhs_mapping)) {
gdata.lhs_mapping = static_cast<Idx*>(lhs_mapping->data); gdata.lhs_mapping = static_cast<Idx*>(lhs_mapping->data);
} }
if (!utils::IsNoneArray(rhs_mapping)) { if (!aten::IsNullArray(rhs_mapping)) {
gdata.rhs_mapping = static_cast<Idx*>(rhs_mapping->data); gdata.rhs_mapping = static_cast<Idx*>(rhs_mapping->data);
} }
if (!utils::IsNoneArray(out_mapping)) { if (!aten::IsNullArray(out_mapping)) {
gdata.out_mapping = static_cast<Idx*>(out_mapping->data); gdata.out_mapping = static_cast<Idx*>(out_mapping->data);
} }
gdata.data_len = info.data_len; gdata.data_len = info.data_len;
...@@ -360,13 +360,13 @@ BackwardBcastGData<NDim, Idx, DType> AllocBackwardBcastGData( ...@@ -360,13 +360,13 @@ BackwardBcastGData<NDim, Idx, DType> AllocBackwardBcastGData(
gdata.rhs_data = static_cast<DType*>(rhs->data); gdata.rhs_data = static_cast<DType*>(rhs->data);
gdata.out_data = static_cast<DType*>(out->data); gdata.out_data = static_cast<DType*>(out->data);
gdata.grad_out_data = static_cast<DType*>(grad_out->data); gdata.grad_out_data = static_cast<DType*>(grad_out->data);
if (!utils::IsNoneArray(grad_lhs)) { if (!aten::IsNullArray(grad_lhs)) {
gdata.grad_lhs_data = static_cast<DType*>(grad_lhs->data); gdata.grad_lhs_data = static_cast<DType*>(grad_lhs->data);
// fill out data with zero values // fill out data with zero values
utils::Fill<XPU>(ctx, gdata.grad_lhs_data, utils::NElements(grad_lhs), utils::Fill<XPU>(ctx, gdata.grad_lhs_data, utils::NElements(grad_lhs),
static_cast<DType>(0)); static_cast<DType>(0));
} }
if (!utils::IsNoneArray(grad_rhs)) { if (!aten::IsNullArray(grad_rhs)) {
gdata.grad_rhs_data = static_cast<DType*>(grad_rhs->data); gdata.grad_rhs_data = static_cast<DType*>(grad_rhs->data);
// fill out data with zero values // fill out data with zero values
utils::Fill<XPU>(ctx, gdata.grad_rhs_data, utils::NElements(grad_rhs), utils::Fill<XPU>(ctx, gdata.grad_rhs_data, utils::NElements(grad_rhs),
...@@ -405,8 +405,8 @@ void BackwardBinaryReduceBcastImpl( ...@@ -405,8 +405,8 @@ void BackwardBinaryReduceBcastImpl(
const DLDataType& dtype = out->dtype; const DLDataType& dtype = out->dtype;
const int bcast_ndim = info.out_shape.size(); const int bcast_ndim = info.out_shape.size();
const bool req_lhs = !utils::IsNoneArray(grad_lhs); const bool req_lhs = !aten::IsNullArray(grad_lhs);
const bool req_rhs = !utils::IsNoneArray(grad_rhs); const bool req_rhs = !aten::IsNullArray(grad_rhs);
const auto bits = graph.NumBits(); const auto bits = graph.NumBits();
if (reducer == binary_op::kReduceMean) { if (reducer == binary_op::kReduceMean) {
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
* \file kernel/utils.cc * \file kernel/utils.cc
* \brief Kernel utilities * \brief Kernel utilities
*/ */
#include <dgl/array.h>
#include <vector> #include <vector>
#include <string> #include <string>
...@@ -30,7 +31,7 @@ int64_t ComputeXLength(runtime::NDArray feat_array) { ...@@ -30,7 +31,7 @@ int64_t ComputeXLength(runtime::NDArray feat_array) {
} }
int64_t NElements(const runtime::NDArray& array) { int64_t NElements(const runtime::NDArray& array) {
if (IsNoneArray(array)) { if (aten::IsNullArray(array)) {
return 0; return 0;
} else { } else {
int64_t ret = 1; int64_t ret = 1;
......
...@@ -17,16 +17,6 @@ namespace dgl { ...@@ -17,16 +17,6 @@ namespace dgl {
namespace kernel { namespace kernel {
namespace utils { namespace utils {
/* !\brief Return an NDArray that represents none value. */
inline runtime::NDArray NoneArray() {
return runtime::NDArray::Empty({}, DLDataType{kDLInt, 32, 1}, DLContext{kDLCPU, 0});
}
/* !\brief Return true if the NDArray is none. */
inline bool IsNoneArray(runtime::NDArray array) {
return array->ndim == 0;
}
/* /*
* !\brief Find number of threads is smaller than dim and max_nthrs * !\brief Find number of threads is smaller than dim and max_nthrs
* and is also the power of two. * and is also the power of two.
......
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#include <dgl/random.h> #include <dgl/random.h>
#include <dgl/array.h> #include <dgl/array.h>
#include <vector> #include <vector>
#include <numeric>
#include "sample_utils.h" #include "sample_utils.h"
namespace dgl { namespace dgl {
...@@ -27,63 +28,71 @@ template int64_t RandomEngine::Choice<int64_t>(FloatArray); ...@@ -27,63 +28,71 @@ template int64_t RandomEngine::Choice<int64_t>(FloatArray);
template<typename IdxType, typename FloatType> template<typename IdxType, typename FloatType>
IdArray RandomEngine::Choice(int64_t num, FloatArray prob, bool replace) { void RandomEngine::Choice(IdxType num, FloatArray prob, IdxType* out, bool replace) {
const int64_t N = prob->shape[0]; const IdxType N = prob->shape[0];
if (!replace) if (!replace)
CHECK_LE(num, N) << "Cannot take more sample than population when 'replace=false'"; CHECK_LE(num, N) << "Cannot take more sample than population when 'replace=false'";
if (num == N && !replace) if (num == N && !replace)
return aten::Range(0, N, sizeof(IdxType) * 8, DLContext{kDLCPU, 0}); std::iota(out, out + num, 0);
const DLDataType dtype{kDLInt, sizeof(IdxType) * 8, 1};
IdArray ret = IdArray::Empty({num}, dtype, DLContext{kDLCPU, 0});
IdxType* ret_data = static_cast<IdxType*>(ret->data);
utils::BaseSampler<IdxType>* sampler = nullptr; utils::BaseSampler<IdxType>* sampler = nullptr;
if (replace) { if (replace) {
sampler = new utils::TreeSampler<IdxType, FloatType, true>(this, prob); sampler = new utils::TreeSampler<IdxType, FloatType, true>(this, prob);
} else { } else {
sampler = new utils::TreeSampler<IdxType, FloatType, false>(this, prob); sampler = new utils::TreeSampler<IdxType, FloatType, false>(this, prob);
} }
for (int64_t i = 0; i < num; ++i) for (IdxType i = 0; i < num; ++i)
ret_data[i] = sampler->Draw(); out[i] = sampler->Draw();
delete sampler; delete sampler;
return ret;
} }
template IdArray RandomEngine::Choice<int32_t, float>( template void RandomEngine::Choice<int32_t, float>(
int64_t num, FloatArray prob, bool replace); int32_t num, FloatArray prob, int32_t* out, bool replace);
template IdArray RandomEngine::Choice<int64_t, float>( template void RandomEngine::Choice<int64_t, float>(
int64_t num, FloatArray prob, bool replace); int64_t num, FloatArray prob, int64_t* out, bool replace);
template IdArray RandomEngine::Choice<int32_t, double>( template void RandomEngine::Choice<int32_t, double>(
int64_t num, FloatArray prob, bool replace); int32_t num, FloatArray prob, int32_t* out, bool replace);
template IdArray RandomEngine::Choice<int64_t, double>( template void RandomEngine::Choice<int64_t, double>(
int64_t num, FloatArray prob, bool replace); int64_t num, FloatArray prob, int64_t* out, bool replace);
template <typename IdxType> template <typename IdxType>
IdArray RandomEngine::UniformChoice(int64_t num, int64_t population, bool replace) { void RandomEngine::UniformChoice(IdxType num, IdxType population, IdxType* out, bool replace) {
if (!replace) if (!replace)
CHECK_LE(num, population) << "Cannot take more sample than population when 'replace=false'"; CHECK_LE(num, population) << "Cannot take more sample than population when 'replace=false'";
const DLDataType dtype{kDLInt, sizeof(IdxType) * 8, 1};
IdArray ret = IdArray::Empty({num}, dtype, DLContext{kDLCPU, 0});
IdxType* ret_data = static_cast<IdxType*>(ret->data);
if (replace) { if (replace) {
for (int64_t i = 0; i < num; ++i) for (IdxType i = 0; i < num; ++i)
ret_data[i] = RandInt(population); out[i] = RandInt(population);
} else {
if (num < population / 10) { // TODO(minjie): may need a better threshold here
// use hash set
// In the best scenario, time complexity is O(num), i.e., no conflict.
//
// Let k be num / population, the expected number of extra sampling steps is roughly
// k^2 / (1-k) * population, which means in the worst case scenario,
// the time complexity is O(population^2). In practice, we use 1/10 since
// std::unordered_set is pretty slow.
std::unordered_set<IdxType> selected;
while (selected.size() < num) {
selected.insert(RandInt(population));
}
std::copy(selected.begin(), selected.end(), out);
} else { } else {
// reservoir algorithm
// time: O(population), space: O(num) // time: O(population), space: O(num)
for (int64_t i = 0; i < num; ++i) for (IdxType i = 0; i < num; ++i)
ret_data[i] = i; out[i] = i;
for (uint64_t i = num; i < population; ++i) { for (IdxType i = num; i < population; ++i) {
const int64_t j = RandInt(i); const IdxType j = RandInt(i);
if (j < num) if (j < num)
ret_data[j] = i; out[j] = i;
}
} }
} }
return ret;
} }
template IdArray RandomEngine::UniformChoice<int32_t>( template void RandomEngine::UniformChoice<int32_t>(
int64_t num, int64_t population, bool replace); int32_t num, int32_t population, int32_t* out, bool replace);
template IdArray RandomEngine::UniformChoice<int64_t>( template void RandomEngine::UniformChoice<int64_t>(
int64_t num, int64_t population, bool replace); int64_t num, int64_t population, int64_t* out, bool replace);
}; // namespace dgl }; // namespace dgl
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include <dgl/runtime/registry.h> #include <dgl/runtime/registry.h>
#include <dgl/runtime/packed_func.h> #include <dgl/runtime/packed_func.h>
#include <dgl/random.h> #include <dgl/random.h>
#include <dgl/array.h>
using namespace dgl::runtime; using namespace dgl::runtime;
...@@ -15,10 +16,38 @@ namespace dgl { ...@@ -15,10 +16,38 @@ namespace dgl {
DGL_REGISTER_GLOBAL("rng._CAPI_SetSeed") DGL_REGISTER_GLOBAL("rng._CAPI_SetSeed")
.set_body([] (DGLArgs args, DGLRetValue *rv) { .set_body([] (DGLArgs args, DGLRetValue *rv) {
int seed = args[0]; const int seed = args[0];
#pragma omp parallel for #pragma omp parallel for
for (int i = 0; i < omp_get_max_threads(); ++i) for (int i = 0; i < omp_get_max_threads(); ++i)
RandomEngine::ThreadLocal()->SetSeed(seed); RandomEngine::ThreadLocal()->SetSeed(seed);
}); });
DGL_REGISTER_GLOBAL("rng._CAPI_Choice")
.set_body([] (DGLArgs args, DGLRetValue *rv) {
const int64_t num = args[0];
const int64_t population = args[1];
const NDArray prob = args[2];
const bool replace = args[3];
const int bits = args[4];
CHECK(bits == 32 || bits == 64)
<< "Supported bit widths are 32 and 64, but got " << bits << ".";
if (aten::IsNullArray(prob)) {
if (bits == 32) {
*rv = RandomEngine::ThreadLocal()->UniformChoice<int32_t>(num, population, replace);
} else {
*rv = RandomEngine::ThreadLocal()->UniformChoice<int64_t>(num, population, replace);
}
} else {
if (bits == 32) {
ATEN_FLOAT_TYPE_SWITCH(prob->dtype, FloatType, "probability", {
*rv = RandomEngine::ThreadLocal()->Choice<int32_t, FloatType>(num, prob, replace);
});
} else {
ATEN_FLOAT_TYPE_SWITCH(prob->dtype, FloatType, "probability", {
*rv = RandomEngine::ThreadLocal()->Choice<int64_t, FloatType>(num, prob, replace);
});
}
}
});
}; // namespace dgl }; // namespace dgl
...@@ -29,7 +29,7 @@ struct TypeManager { ...@@ -29,7 +29,7 @@ struct TypeManager {
}; };
} // namespace } // namespace
const bool Object::_DerivedFrom(uint32_t tid) const { bool Object::_DerivedFrom(uint32_t tid) const {
static uint32_t tindex = TypeKey2Index(Object::_type_key); static uint32_t tindex = TypeKey2Index(Object::_type_key);
return tid == tindex; return tid == tindex;
} }
......
import dgl
import backend as F
import numpy as np
import unittest
@unittest.skipIf(F._default_context_str == 'gpu', reason="GPU random choice not implemented")
def test_rand_graph():
    # The generated graph must honor the requested node/edge counts.
    big = dgl.rand_graph(10000, 100000)
    assert big.number_of_nodes() == 10000
    assert big.number_of_edges() == 100000
    # Re-seeding with the same value must reproduce the exact same graph.
    dgl.random.seed(42)
    first = dgl.rand_graph(100, 30)
    dgl.random.seed(42)
    second = dgl.rand_graph(100, 30)
    src1, dst1 = first.edges()
    src2, dst2 = second.edges()
    assert F.array_equal(src1, src2)
    assert F.array_equal(dst1, dst2)
# Allow running this test module directly without a test runner.
if __name__ == '__main__':
    test_rand_graph()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment