Unverified commit a9520f71, authored by Quan (Andy) Gan and committed by GitHub
Browse files

[Model][Sampler] GraphSAGE model, bipartite graph conversion & remove edges API (#1297)

* remove edge and to bipartite and graphsage with sampling

* fixes

* fixes

* fixes

* reenable multigpu training

* fixes

* compatibility in DGLGraph

* rename to compact_as_bipartite

* bugfix

* lint

* add offline inference

* skip GPU tests

* fix

* addresses comments

* fix

* fix

* fix

* more tests

* more docs and unit tests

* workaround for empty slice on empty data
parent ce6e19f2
...@@ -12,7 +12,7 @@ HeteroGraphPtr DisjointUnionHeteroGraph( ...@@ -12,7 +12,7 @@ HeteroGraphPtr DisjointUnionHeteroGraph(
GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& component_graphs) { GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& component_graphs) {
CHECK_GT(component_graphs.size(), 0) << "Input graph list is empty"; CHECK_GT(component_graphs.size(), 0) << "Input graph list is empty";
std::vector<HeteroGraphPtr> rel_graphs(meta_graph->NumEdges()); std::vector<HeteroGraphPtr> rel_graphs(meta_graph->NumEdges());
std::vector<int64_t> num_nodes_per_type(meta_graph->NumVertices()); std::vector<int64_t> num_nodes_per_type(meta_graph->NumVertices(), 0);
// Loop over all canonical etypes // Loop over all canonical etypes
for (dgl_type_t etype = 0; etype < meta_graph->NumEdges(); ++etype) { for (dgl_type_t etype = 0; etype < meta_graph->NumEdges(); ++etype) {
......
...@@ -83,6 +83,7 @@ class UnitGraph::COO : public BaseHeteroGraph { ...@@ -83,6 +83,7 @@ class UnitGraph::COO : public BaseHeteroGraph {
: BaseHeteroGraph(metagraph), adj_(coo) { : BaseHeteroGraph(metagraph), adj_(coo) {
// Data index should not be inherited. Edges in COO format are always // Data index should not be inherited. Edges in COO format are always
// assigned ids from 0 to num_edges - 1. // assigned ids from 0 to num_edges - 1.
CHECK(!COOHasData(coo)) << "[BUG] COO should not contain data.";
adj_.data = aten::NullArray(); adj_.data = aten::NullArray();
} }
...@@ -344,7 +345,7 @@ class UnitGraph::COO : public BaseHeteroGraph { ...@@ -344,7 +345,7 @@ class UnitGraph::COO : public BaseHeteroGraph {
SparseFormat SelectFormat(dgl_type_t etype, SparseFormat preferred_format) const override { SparseFormat SelectFormat(dgl_type_t etype, SparseFormat preferred_format) const override {
LOG(FATAL) << "Not enabled for COO graph"; LOG(FATAL) << "Not enabled for COO graph";
return SparseFormat::ANY; return SparseFormat::kAny;
} }
HeteroSubgraph VertexSubgraph(const std::vector<IdArray>& vids) const override { HeteroSubgraph VertexSubgraph(const std::vector<IdArray>& vids) const override {
...@@ -443,6 +444,7 @@ class UnitGraph::CSR : public BaseHeteroGraph { ...@@ -443,6 +444,7 @@ class UnitGraph::CSR : public BaseHeteroGraph {
CHECK(aten::IsValidIdArray(edge_ids)); CHECK(aten::IsValidIdArray(edge_ids));
CHECK_EQ(indices->shape[0], edge_ids->shape[0]) CHECK_EQ(indices->shape[0], edge_ids->shape[0])
<< "indices and edge id arrays should have the same length"; << "indices and edge id arrays should have the same length";
adj_ = aten::CSRMatrix{num_src, num_dst, indptr, indices, edge_ids}; adj_ = aten::CSRMatrix{num_src, num_dst, indptr, indices, edge_ids};
} }
...@@ -724,7 +726,7 @@ class UnitGraph::CSR : public BaseHeteroGraph { ...@@ -724,7 +726,7 @@ class UnitGraph::CSR : public BaseHeteroGraph {
SparseFormat SelectFormat(dgl_type_t etype, SparseFormat preferred_format) const override { SparseFormat SelectFormat(dgl_type_t etype, SparseFormat preferred_format) const override {
LOG(FATAL) << "Not enabled for CSR graph"; LOG(FATAL) << "Not enabled for CSR graph";
return SparseFormat::ANY; return SparseFormat::kAny;
} }
HeteroSubgraph VertexSubgraph(const std::vector<IdArray>& vids) const override { HeteroSubgraph VertexSubgraph(const std::vector<IdArray>& vids) const override {
...@@ -801,11 +803,11 @@ bool UnitGraph::IsMultigraph() const { ...@@ -801,11 +803,11 @@ bool UnitGraph::IsMultigraph() const {
} }
uint64_t UnitGraph::NumVertices(dgl_type_t vtype) const { uint64_t UnitGraph::NumVertices(dgl_type_t vtype) const {
const SparseFormat fmt = SelectFormat(SparseFormat::ANY); const SparseFormat fmt = SelectFormat(SparseFormat::kAny);
const auto ptr = GetFormat(fmt); const auto ptr = GetFormat(fmt);
// TODO(BarclayII): we have a lot of special handling for CSC. // TODO(BarclayII): we have a lot of special handling for CSC.
// Need to have a UnitGraph::CSC backend instead. // Need to have a UnitGraph::CSC backend instead.
if (fmt == SparseFormat::CSC) if (fmt == SparseFormat::kCSC)
vtype = (vtype == SrcType()) ? DstType() : SrcType(); vtype = (vtype == SrcType()) ? DstType() : SrcType();
return ptr->NumVertices(vtype); return ptr->NumVertices(vtype);
} }
...@@ -815,9 +817,9 @@ uint64_t UnitGraph::NumEdges(dgl_type_t etype) const { ...@@ -815,9 +817,9 @@ uint64_t UnitGraph::NumEdges(dgl_type_t etype) const {
} }
bool UnitGraph::HasVertex(dgl_type_t vtype, dgl_id_t vid) const { bool UnitGraph::HasVertex(dgl_type_t vtype, dgl_id_t vid) const {
const SparseFormat fmt = SelectFormat(SparseFormat::ANY); const SparseFormat fmt = SelectFormat(SparseFormat::kAny);
const auto ptr = GetFormat(fmt); const auto ptr = GetFormat(fmt);
if (fmt == SparseFormat::CSC) if (fmt == SparseFormat::kCSC)
vtype = (vtype == SrcType()) ? DstType() : SrcType(); vtype = (vtype == SrcType()) ? DstType() : SrcType();
return ptr->HasVertex(vtype, vid); return ptr->HasVertex(vtype, vid);
} }
...@@ -828,9 +830,9 @@ BoolArray UnitGraph::HasVertices(dgl_type_t vtype, IdArray vids) const { ...@@ -828,9 +830,9 @@ BoolArray UnitGraph::HasVertices(dgl_type_t vtype, IdArray vids) const {
} }
bool UnitGraph::HasEdgeBetween(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const { bool UnitGraph::HasEdgeBetween(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const {
const SparseFormat fmt = SelectFormat(SparseFormat::ANY); const SparseFormat fmt = SelectFormat(SparseFormat::kAny);
const auto ptr = GetFormat(fmt); const auto ptr = GetFormat(fmt);
if (fmt == SparseFormat::CSC) if (fmt == SparseFormat::kCSC)
return ptr->HasEdgeBetween(etype, dst, src); return ptr->HasEdgeBetween(etype, dst, src);
else else
return ptr->HasEdgeBetween(etype, src, dst); return ptr->HasEdgeBetween(etype, src, dst);
...@@ -838,42 +840,42 @@ bool UnitGraph::HasEdgeBetween(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) con ...@@ -838,42 +840,42 @@ bool UnitGraph::HasEdgeBetween(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) con
BoolArray UnitGraph::HasEdgesBetween( BoolArray UnitGraph::HasEdgesBetween(
dgl_type_t etype, IdArray src, IdArray dst) const { dgl_type_t etype, IdArray src, IdArray dst) const {
const SparseFormat fmt = SelectFormat(SparseFormat::ANY); const SparseFormat fmt = SelectFormat(SparseFormat::kAny);
const auto ptr = GetFormat(fmt); const auto ptr = GetFormat(fmt);
if (fmt == SparseFormat::CSC) if (fmt == SparseFormat::kCSC)
return ptr->HasEdgesBetween(etype, dst, src); return ptr->HasEdgesBetween(etype, dst, src);
else else
return ptr->HasEdgesBetween(etype, src, dst); return ptr->HasEdgesBetween(etype, src, dst);
} }
IdArray UnitGraph::Predecessors(dgl_type_t etype, dgl_id_t dst) const { IdArray UnitGraph::Predecessors(dgl_type_t etype, dgl_id_t dst) const {
const SparseFormat fmt = SelectFormat(SparseFormat::CSC); const SparseFormat fmt = SelectFormat(SparseFormat::kCSC);
const auto ptr = GetFormat(fmt); const auto ptr = GetFormat(fmt);
if (fmt == SparseFormat::CSC) if (fmt == SparseFormat::kCSC)
return ptr->Successors(etype, dst); return ptr->Successors(etype, dst);
else else
return ptr->Predecessors(etype, dst); return ptr->Predecessors(etype, dst);
} }
IdArray UnitGraph::Successors(dgl_type_t etype, dgl_id_t src) const { IdArray UnitGraph::Successors(dgl_type_t etype, dgl_id_t src) const {
const SparseFormat fmt = SelectFormat(SparseFormat::CSR); const SparseFormat fmt = SelectFormat(SparseFormat::kCSR);
const auto ptr = GetFormat(fmt); const auto ptr = GetFormat(fmt);
return ptr->Successors(etype, src); return ptr->Successors(etype, src);
} }
IdArray UnitGraph::EdgeId(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const { IdArray UnitGraph::EdgeId(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const {
const SparseFormat fmt = SelectFormat(SparseFormat::ANY); const SparseFormat fmt = SelectFormat(SparseFormat::kAny);
const auto ptr = GetFormat(fmt); const auto ptr = GetFormat(fmt);
if (fmt == SparseFormat::CSC) if (fmt == SparseFormat::kCSC)
return ptr->EdgeId(etype, dst, src); return ptr->EdgeId(etype, dst, src);
else else
return ptr->EdgeId(etype, src, dst); return ptr->EdgeId(etype, src, dst);
} }
EdgeArray UnitGraph::EdgeIds(dgl_type_t etype, IdArray src, IdArray dst) const { EdgeArray UnitGraph::EdgeIds(dgl_type_t etype, IdArray src, IdArray dst) const {
const SparseFormat fmt = SelectFormat(SparseFormat::ANY); const SparseFormat fmt = SelectFormat(SparseFormat::kAny);
const auto ptr = GetFormat(fmt); const auto ptr = GetFormat(fmt);
if (fmt == SparseFormat::CSC) { if (fmt == SparseFormat::kCSC) {
EdgeArray edges = ptr->EdgeIds(etype, dst, src); EdgeArray edges = ptr->EdgeIds(etype, dst, src);
return EdgeArray{edges.dst, edges.src, edges.id}; return EdgeArray{edges.dst, edges.src, edges.id};
} else { } else {
...@@ -882,21 +884,21 @@ EdgeArray UnitGraph::EdgeIds(dgl_type_t etype, IdArray src, IdArray dst) const { ...@@ -882,21 +884,21 @@ EdgeArray UnitGraph::EdgeIds(dgl_type_t etype, IdArray src, IdArray dst) const {
} }
std::pair<dgl_id_t, dgl_id_t> UnitGraph::FindEdge(dgl_type_t etype, dgl_id_t eid) const { std::pair<dgl_id_t, dgl_id_t> UnitGraph::FindEdge(dgl_type_t etype, dgl_id_t eid) const {
const SparseFormat fmt = SelectFormat(SparseFormat::COO); const SparseFormat fmt = SelectFormat(SparseFormat::kCOO);
const auto ptr = GetFormat(fmt); const auto ptr = GetFormat(fmt);
return ptr->FindEdge(etype, eid); return ptr->FindEdge(etype, eid);
} }
EdgeArray UnitGraph::FindEdges(dgl_type_t etype, IdArray eids) const { EdgeArray UnitGraph::FindEdges(dgl_type_t etype, IdArray eids) const {
const SparseFormat fmt = SelectFormat(SparseFormat::COO); const SparseFormat fmt = SelectFormat(SparseFormat::kCOO);
const auto ptr = GetFormat(fmt); const auto ptr = GetFormat(fmt);
return ptr->FindEdges(etype, eids); return ptr->FindEdges(etype, eids);
} }
EdgeArray UnitGraph::InEdges(dgl_type_t etype, dgl_id_t vid) const { EdgeArray UnitGraph::InEdges(dgl_type_t etype, dgl_id_t vid) const {
const SparseFormat fmt = SelectFormat(SparseFormat::CSC); const SparseFormat fmt = SelectFormat(SparseFormat::kCSC);
const auto ptr = GetFormat(fmt); const auto ptr = GetFormat(fmt);
if (fmt == SparseFormat::CSC) { if (fmt == SparseFormat::kCSC) {
const EdgeArray& ret = ptr->OutEdges(etype, vid); const EdgeArray& ret = ptr->OutEdges(etype, vid);
return {ret.dst, ret.src, ret.id}; return {ret.dst, ret.src, ret.id};
} else { } else {
...@@ -905,9 +907,9 @@ EdgeArray UnitGraph::InEdges(dgl_type_t etype, dgl_id_t vid) const { ...@@ -905,9 +907,9 @@ EdgeArray UnitGraph::InEdges(dgl_type_t etype, dgl_id_t vid) const {
} }
EdgeArray UnitGraph::InEdges(dgl_type_t etype, IdArray vids) const { EdgeArray UnitGraph::InEdges(dgl_type_t etype, IdArray vids) const {
const SparseFormat fmt = SelectFormat(SparseFormat::CSC); const SparseFormat fmt = SelectFormat(SparseFormat::kCSC);
const auto ptr = GetFormat(fmt); const auto ptr = GetFormat(fmt);
if (fmt == SparseFormat::CSC) { if (fmt == SparseFormat::kCSC) {
const EdgeArray& ret = ptr->OutEdges(etype, vids); const EdgeArray& ret = ptr->OutEdges(etype, vids);
return {ret.dst, ret.src, ret.id}; return {ret.dst, ret.src, ret.id};
} else { } else {
...@@ -916,13 +918,13 @@ EdgeArray UnitGraph::InEdges(dgl_type_t etype, IdArray vids) const { ...@@ -916,13 +918,13 @@ EdgeArray UnitGraph::InEdges(dgl_type_t etype, IdArray vids) const {
} }
EdgeArray UnitGraph::OutEdges(dgl_type_t etype, dgl_id_t vid) const { EdgeArray UnitGraph::OutEdges(dgl_type_t etype, dgl_id_t vid) const {
const SparseFormat fmt = SelectFormat(SparseFormat::CSR); const SparseFormat fmt = SelectFormat(SparseFormat::kCSR);
const auto ptr = GetFormat(fmt); const auto ptr = GetFormat(fmt);
return ptr->OutEdges(etype, vid); return ptr->OutEdges(etype, vid);
} }
EdgeArray UnitGraph::OutEdges(dgl_type_t etype, IdArray vids) const { EdgeArray UnitGraph::OutEdges(dgl_type_t etype, IdArray vids) const {
const SparseFormat fmt = SelectFormat(SparseFormat::CSR); const SparseFormat fmt = SelectFormat(SparseFormat::kCSR);
const auto ptr = GetFormat(fmt); const auto ptr = GetFormat(fmt);
return ptr->OutEdges(etype, vids); return ptr->OutEdges(etype, vids);
} }
...@@ -930,79 +932,79 @@ EdgeArray UnitGraph::OutEdges(dgl_type_t etype, IdArray vids) const { ...@@ -930,79 +932,79 @@ EdgeArray UnitGraph::OutEdges(dgl_type_t etype, IdArray vids) const {
EdgeArray UnitGraph::Edges(dgl_type_t etype, const std::string &order) const { EdgeArray UnitGraph::Edges(dgl_type_t etype, const std::string &order) const {
SparseFormat fmt; SparseFormat fmt;
if (order == std::string("eid")) { if (order == std::string("eid")) {
fmt = SelectFormat(SparseFormat::COO); fmt = SelectFormat(SparseFormat::kCOO);
} else if (order.empty()) { } else if (order.empty()) {
// arbitrary order // arbitrary order
fmt = SelectFormat(SparseFormat::ANY); fmt = SelectFormat(SparseFormat::kAny);
} else if (order == std::string("srcdst")) { } else if (order == std::string("srcdst")) {
fmt = SelectFormat(SparseFormat::CSR); fmt = SelectFormat(SparseFormat::kCSR);
} else { } else {
LOG(FATAL) << "Unsupported order request: " << order; LOG(FATAL) << "Unsupported order request: " << order;
return {}; return {};
} }
const auto& edges = GetFormat(fmt)->Edges(etype, order); const auto& edges = GetFormat(fmt)->Edges(etype, order);
if (fmt == SparseFormat::CSC) if (fmt == SparseFormat::kCSC)
return EdgeArray{edges.dst, edges.src, edges.id}; return EdgeArray{edges.dst, edges.src, edges.id};
else else
return edges; return edges;
} }
uint64_t UnitGraph::InDegree(dgl_type_t etype, dgl_id_t vid) const { uint64_t UnitGraph::InDegree(dgl_type_t etype, dgl_id_t vid) const {
SparseFormat fmt = SelectFormat(SparseFormat::CSC); SparseFormat fmt = SelectFormat(SparseFormat::kCSC);
const auto ptr = GetFormat(fmt); const auto ptr = GetFormat(fmt);
if (fmt == SparseFormat::CSC) if (fmt == SparseFormat::kCSC)
return ptr->OutDegree(etype, vid); return ptr->OutDegree(etype, vid);
else else
return ptr->InDegree(etype, vid); return ptr->InDegree(etype, vid);
} }
DegreeArray UnitGraph::InDegrees(dgl_type_t etype, IdArray vids) const { DegreeArray UnitGraph::InDegrees(dgl_type_t etype, IdArray vids) const {
SparseFormat fmt = SelectFormat(SparseFormat::CSC); SparseFormat fmt = SelectFormat(SparseFormat::kCSC);
const auto ptr = GetFormat(fmt); const auto ptr = GetFormat(fmt);
if (fmt == SparseFormat::CSC) if (fmt == SparseFormat::kCSC)
return ptr->OutDegrees(etype, vids); return ptr->OutDegrees(etype, vids);
else else
return ptr->InDegrees(etype, vids); return ptr->InDegrees(etype, vids);
} }
uint64_t UnitGraph::OutDegree(dgl_type_t etype, dgl_id_t vid) const { uint64_t UnitGraph::OutDegree(dgl_type_t etype, dgl_id_t vid) const {
SparseFormat fmt = SelectFormat(SparseFormat::CSR); SparseFormat fmt = SelectFormat(SparseFormat::kCSR);
const auto ptr = GetFormat(fmt); const auto ptr = GetFormat(fmt);
return ptr->OutDegree(etype, vid); return ptr->OutDegree(etype, vid);
} }
DegreeArray UnitGraph::OutDegrees(dgl_type_t etype, IdArray vids) const { DegreeArray UnitGraph::OutDegrees(dgl_type_t etype, IdArray vids) const {
SparseFormat fmt = SelectFormat(SparseFormat::CSR); SparseFormat fmt = SelectFormat(SparseFormat::kCSR);
const auto ptr = GetFormat(fmt); const auto ptr = GetFormat(fmt);
return ptr->OutDegrees(etype, vids); return ptr->OutDegrees(etype, vids);
} }
DGLIdIters UnitGraph::SuccVec(dgl_type_t etype, dgl_id_t vid) const { DGLIdIters UnitGraph::SuccVec(dgl_type_t etype, dgl_id_t vid) const {
SparseFormat fmt = SelectFormat(SparseFormat::CSR); SparseFormat fmt = SelectFormat(SparseFormat::kCSR);
const auto ptr = GetFormat(fmt); const auto ptr = GetFormat(fmt);
return ptr->SuccVec(etype, vid); return ptr->SuccVec(etype, vid);
} }
DGLIdIters UnitGraph::OutEdgeVec(dgl_type_t etype, dgl_id_t vid) const { DGLIdIters UnitGraph::OutEdgeVec(dgl_type_t etype, dgl_id_t vid) const {
SparseFormat fmt = SelectFormat(SparseFormat::CSR); SparseFormat fmt = SelectFormat(SparseFormat::kCSR);
const auto ptr = GetFormat(fmt); const auto ptr = GetFormat(fmt);
return ptr->OutEdgeVec(etype, vid); return ptr->OutEdgeVec(etype, vid);
} }
DGLIdIters UnitGraph::PredVec(dgl_type_t etype, dgl_id_t vid) const { DGLIdIters UnitGraph::PredVec(dgl_type_t etype, dgl_id_t vid) const {
SparseFormat fmt = SelectFormat(SparseFormat::CSC); SparseFormat fmt = SelectFormat(SparseFormat::kCSC);
const auto ptr = GetFormat(fmt); const auto ptr = GetFormat(fmt);
if (fmt == SparseFormat::CSC) if (fmt == SparseFormat::kCSC)
return ptr->SuccVec(etype, vid); return ptr->SuccVec(etype, vid);
else else
return ptr->PredVec(etype, vid); return ptr->PredVec(etype, vid);
} }
DGLIdIters UnitGraph::InEdgeVec(dgl_type_t etype, dgl_id_t vid) const { DGLIdIters UnitGraph::InEdgeVec(dgl_type_t etype, dgl_id_t vid) const {
SparseFormat fmt = SelectFormat(SparseFormat::CSC); SparseFormat fmt = SelectFormat(SparseFormat::kCSC);
const auto ptr = GetFormat(fmt); const auto ptr = GetFormat(fmt);
if (fmt == SparseFormat::CSC) if (fmt == SparseFormat::kCSC)
return ptr->OutEdgeVec(etype, vid); return ptr->OutEdgeVec(etype, vid);
else else
return ptr->InEdgeVec(etype, vid); return ptr->InEdgeVec(etype, vid);
...@@ -1030,7 +1032,7 @@ std::vector<IdArray> UnitGraph::GetAdj( ...@@ -1030,7 +1032,7 @@ std::vector<IdArray> UnitGraph::GetAdj(
HeteroSubgraph UnitGraph::VertexSubgraph(const std::vector<IdArray>& vids) const { HeteroSubgraph UnitGraph::VertexSubgraph(const std::vector<IdArray>& vids) const {
// We prefer to generate a subgraph from out-csr. // We prefer to generate a subgraph from out-csr.
SparseFormat fmt = SelectFormat(SparseFormat::CSR); SparseFormat fmt = SelectFormat(SparseFormat::kCSR);
HeteroSubgraph sg = GetFormat(fmt)->VertexSubgraph(vids); HeteroSubgraph sg = GetFormat(fmt)->VertexSubgraph(vids);
CSRPtr subcsr = std::dynamic_pointer_cast<CSR>(sg.graph); CSRPtr subcsr = std::dynamic_pointer_cast<CSR>(sg.graph);
HeteroSubgraph ret; HeteroSubgraph ret;
...@@ -1042,7 +1044,7 @@ HeteroSubgraph UnitGraph::VertexSubgraph(const std::vector<IdArray>& vids) const ...@@ -1042,7 +1044,7 @@ HeteroSubgraph UnitGraph::VertexSubgraph(const std::vector<IdArray>& vids) const
HeteroSubgraph UnitGraph::EdgeSubgraph( HeteroSubgraph UnitGraph::EdgeSubgraph(
const std::vector<IdArray>& eids, bool preserve_nodes) const { const std::vector<IdArray>& eids, bool preserve_nodes) const {
SparseFormat fmt = SelectFormat(SparseFormat::COO); SparseFormat fmt = SelectFormat(SparseFormat::kCOO);
auto sg = GetFormat(fmt)->EdgeSubgraph(eids, preserve_nodes); auto sg = GetFormat(fmt)->EdgeSubgraph(eids, preserve_nodes);
COOPtr subcoo = std::dynamic_pointer_cast<COO>(sg.graph); COOPtr subcoo = std::dynamic_pointer_cast<COO>(sg.graph);
HeteroSubgraph ret; HeteroSubgraph ret;
...@@ -1100,6 +1102,28 @@ HeteroGraphPtr UnitGraph::CreateFromCSR( ...@@ -1100,6 +1102,28 @@ HeteroGraphPtr UnitGraph::CreateFromCSR(
return HeteroGraphPtr(new UnitGraph(mg, nullptr, csr, nullptr, restrict_format)); return HeteroGraphPtr(new UnitGraph(mg, nullptr, csr, nullptr, restrict_format));
} }
/*!
 * \brief Create a unit graph from (in) CSC arrays.
 *
 * The resulting CSR object is handed to the UnitGraph constructor as its
 * in-CSR (first sparse-structure argument); no out-CSR or COO is supplied.
 *
 * \param num_vtypes Number of vertex types; must be 1 or 2.
 * \param num_src Number of source vertices of the graph.
 * \param num_dst Number of destination vertices of the graph.
 *        Must equal num_src when num_vtypes is 1.
 * \param indptr CSC index pointer array.
 * \param indices CSC indices array.
 * \param edge_ids Edge ID array; same length as indices.
 * \param restrict_format Format restriction forwarded to the UnitGraph.
 * \return The created graph.
 */
HeteroGraphPtr UnitGraph::CreateFromCSC(
    int64_t num_vtypes, int64_t num_src, int64_t num_dst,
    IdArray indptr, IdArray indices, IdArray edge_ids, SparseFormat restrict_format) {
  CHECK(num_vtypes == 1 || num_vtypes == 2);
  if (num_vtypes == 1)
    CHECK_EQ(num_src, num_dst);
  auto mg = CreateUnitGraphMetaGraph(num_vtypes);
  // UnitGraph treats its in-CSR as the transposed graph: queries on the CSC
  // format swap source and destination vertex types before delegating to the
  // backend.  The backend's "source" side must therefore hold the graph's
  // destination vertices, so the two sizes are passed in swapped order.
  CSRPtr csc(new CSR(mg, num_dst, num_src, indptr, indices, edge_ids));
  return HeteroGraphPtr(new UnitGraph(mg, csc, nullptr, nullptr, restrict_format));
}
/*!
 * \brief Create a unit graph from an already assembled (in) CSC matrix.
 *
 * The matrix is wrapped as-is into a CSR backend object, which is then given
 * to the UnitGraph constructor as its in-CSR; no out-CSR or COO is supplied.
 *
 * \param num_vtypes Number of vertex types; must be 1 or 2.
 * \param mat The sparse matrix; must be square when num_vtypes is 1.
 * \param restrict_format Format restriction forwarded to the UnitGraph.
 * \return The created graph.
 */
HeteroGraphPtr UnitGraph::CreateFromCSC(
    int64_t num_vtypes, const aten::CSRMatrix& mat,
    SparseFormat restrict_format) {
  CHECK(num_vtypes == 1 || num_vtypes == 2);
  if (num_vtypes == 1) {
    // A single vertex type means source and destination sets coincide.
    CHECK_EQ(mat.num_rows, mat.num_cols);
  }
  auto metagraph = CreateUnitGraphMetaGraph(num_vtypes);
  CSRPtr in_csr(new CSR(metagraph, mat));
  return HeteroGraphPtr(
      new UnitGraph(metagraph, in_csr, nullptr, nullptr, restrict_format));
}
HeteroGraphPtr UnitGraph::AsNumBits(HeteroGraphPtr g, uint8_t bits) { HeteroGraphPtr UnitGraph::AsNumBits(HeteroGraphPtr g, uint8_t bits) {
if (g->NumBits() == bits) { if (g->NumBits() == bits) {
return g; return g;
...@@ -1143,9 +1167,9 @@ UnitGraph::UnitGraph(GraphPtr metagraph, CSRPtr in_csr, CSRPtr out_csr, COOPtr c ...@@ -1143,9 +1167,9 @@ UnitGraph::UnitGraph(GraphPtr metagraph, CSRPtr in_csr, CSRPtr out_csr, COOPtr c
// If the graph is hypersparse and in COO format, switch the restricted format to COO. // If the graph is hypersparse and in COO format, switch the restricted format to COO.
// If the graph is given as CSR, the indptr array is already materialized so we don't // If the graph is given as CSR, the indptr array is already materialized so we don't
// care about restricting conversion anyway (even if it is hypersparse). // care about restricting conversion anyway (even if it is hypersparse).
if (restrict_format == SparseFormat::ANY) { if (restrict_format == SparseFormat::kAny) {
if (coo && coo->IsHypersparse()) if (coo && coo->IsHypersparse())
restrict_format_ = SparseFormat::COO; restrict_format_ = SparseFormat::kCOO;
} }
CHECK(GetAny()) << "At least one graph structure should exist."; CHECK(GetAny()) << "At least one graph structure should exist.";
...@@ -1221,31 +1245,31 @@ HeteroGraphPtr UnitGraph::GetAny() const { ...@@ -1221,31 +1245,31 @@ HeteroGraphPtr UnitGraph::GetAny() const {
HeteroGraphPtr UnitGraph::GetFormat(SparseFormat format) const { HeteroGraphPtr UnitGraph::GetFormat(SparseFormat format) const {
switch (format) { switch (format) {
case SparseFormat::CSR: case SparseFormat::kCSR:
return GetOutCSR(); return GetOutCSR();
case SparseFormat::CSC: case SparseFormat::kCSC:
return GetInCSR(); return GetInCSR();
case SparseFormat::COO: case SparseFormat::kCOO:
return GetCOO(); return GetCOO();
case SparseFormat::ANY: case SparseFormat::kAny:
return GetAny(); return GetAny();
default: default:
LOG(FATAL) << "unsupported format code"; LOG(FATAL) << "unsupported format code";
return nullptr; return nullptr;
} }
} }
SparseFormat UnitGraph::SelectFormat(SparseFormat preferred_format) const { SparseFormat UnitGraph::SelectFormat(SparseFormat preferred_format) const {
if (restrict_format_ != SparseFormat::ANY) if (restrict_format_ != SparseFormat::kAny)
return restrict_format_; return restrict_format_;
else if (preferred_format != SparseFormat::ANY) else if (preferred_format != SparseFormat::kAny)
return preferred_format; return preferred_format;
else if (in_csr_) else if (in_csr_)
return SparseFormat::CSC; return SparseFormat::kCSC;
else if (out_csr_) else if (out_csr_)
return SparseFormat::CSR; return SparseFormat::kCSR;
else else
return SparseFormat::COO; return SparseFormat::kCOO;
} }
constexpr uint64_t kDGLSerialize_UnitGraphMagic = 0xDD2E60F0F6B4A127; constexpr uint64_t kDGLSerialize_UnitGraphMagic = 0xDD2E60F0F6B4A127;
...@@ -1260,13 +1284,13 @@ bool UnitGraph::Load(dmlc::Stream* fs) { ...@@ -1260,13 +1284,13 @@ bool UnitGraph::Load(dmlc::Stream* fs) {
restrict_format_ = static_cast<SparseFormat>(format_code); restrict_format_ = static_cast<SparseFormat>(format_code);
switch (restrict_format_) { switch (restrict_format_) {
case SparseFormat::COO: case SparseFormat::kCOO:
fs->Read(&coo_); fs->Read(&coo_);
break; break;
case SparseFormat::CSR: case SparseFormat::kCSR:
fs->Read(&out_csr_); fs->Read(&out_csr_);
break; break;
case SparseFormat::CSC: case SparseFormat::kCSC:
fs->Read(&in_csr_); fs->Read(&in_csr_);
break; break;
default: default:
...@@ -1284,16 +1308,16 @@ void UnitGraph::Save(dmlc::Stream* fs) const { ...@@ -1284,16 +1308,16 @@ void UnitGraph::Save(dmlc::Stream* fs) const {
fs->Write(kDGLSerialize_UnitGraphMagic); fs->Write(kDGLSerialize_UnitGraphMagic);
// Didn't write UnitGraph::meta_graph_, since it's included in the underlying // Didn't write UnitGraph::meta_graph_, since it's included in the underlying
// sparse matrix // sparse matrix
auto avail_fmt = SelectFormat(SparseFormat::ANY); auto avail_fmt = SelectFormat(SparseFormat::kAny);
fs->Write(static_cast<int64_t>(avail_fmt)); fs->Write(static_cast<int64_t>(avail_fmt));
switch (avail_fmt) { switch (avail_fmt) {
case SparseFormat::COO: case SparseFormat::kCOO:
fs->Write(GetCOO()); fs->Write(GetCOO());
break; break;
case SparseFormat::CSR: case SparseFormat::kCSR:
fs->Write(GetOutCSR()); fs->Write(GetOutCSR());
break; break;
case SparseFormat::CSC: case SparseFormat::kCSC:
fs->Write(GetInCSR()); fs->Write(GetInCSR());
break; break;
default: default:
......
...@@ -166,21 +166,31 @@ class UnitGraph : public BaseHeteroGraph { ...@@ -166,21 +166,31 @@ class UnitGraph : public BaseHeteroGraph {
/*! \brief Create a graph from COO arrays */ /*! \brief Create a graph from COO arrays */
static HeteroGraphPtr CreateFromCOO( static HeteroGraphPtr CreateFromCOO(
int64_t num_vtypes, int64_t num_src, int64_t num_dst, int64_t num_vtypes, int64_t num_src, int64_t num_dst,
IdArray row, IdArray col, SparseFormat restrict_format = SparseFormat::ANY); IdArray row, IdArray col, SparseFormat restrict_format = SparseFormat::kAny);
static HeteroGraphPtr CreateFromCOO( static HeteroGraphPtr CreateFromCOO(
int64_t num_vtypes, const aten::COOMatrix& mat, int64_t num_vtypes, const aten::COOMatrix& mat,
SparseFormat restrict_format = SparseFormat::ANY); SparseFormat restrict_format = SparseFormat::kAny);
/*! \brief Create a graph from (out) CSR arrays */ /*! \brief Create a graph from (out) CSR arrays */
static HeteroGraphPtr CreateFromCSR( static HeteroGraphPtr CreateFromCSR(
int64_t num_vtypes, int64_t num_src, int64_t num_dst, int64_t num_vtypes, int64_t num_src, int64_t num_dst,
IdArray indptr, IdArray indices, IdArray edge_ids, IdArray indptr, IdArray indices, IdArray edge_ids,
SparseFormat restrict_format = SparseFormat::ANY); SparseFormat restrict_format = SparseFormat::kAny);
static HeteroGraphPtr CreateFromCSR( static HeteroGraphPtr CreateFromCSR(
int64_t num_vtypes, const aten::CSRMatrix& mat, int64_t num_vtypes, const aten::CSRMatrix& mat,
SparseFormat restrict_format = SparseFormat::ANY); SparseFormat restrict_format = SparseFormat::kAny);
/*! \brief Create a graph from (in) CSC arrays */
static HeteroGraphPtr CreateFromCSC(
int64_t num_vtypes, int64_t num_src, int64_t num_dst,
IdArray indptr, IdArray indices, IdArray edge_ids,
SparseFormat restrict_format = SparseFormat::kAny);
static HeteroGraphPtr CreateFromCSC(
int64_t num_vtypes, const aten::CSRMatrix& mat,
SparseFormat restrict_format = SparseFormat::kAny);
/*! \brief Convert the graph to use the given number of bits for storage */ /*! \brief Convert the graph to use the given number of bits for storage */
static HeteroGraphPtr AsNumBits(HeteroGraphPtr g, uint8_t bits); static HeteroGraphPtr AsNumBits(HeteroGraphPtr g, uint8_t bits);
...@@ -231,7 +241,7 @@ class UnitGraph : public BaseHeteroGraph { ...@@ -231,7 +241,7 @@ class UnitGraph : public BaseHeteroGraph {
* \param coo coo * \param coo coo
*/ */
UnitGraph(GraphPtr metagraph, CSRPtr in_csr, CSRPtr out_csr, COOPtr coo, UnitGraph(GraphPtr metagraph, CSRPtr in_csr, CSRPtr out_csr, COOPtr coo,
SparseFormat restrict_format = SparseFormat::ANY); SparseFormat restrict_format = SparseFormat::kAny);
/*! \return Return any existing format. */ /*! \return Return any existing format. */
HeteroGraphPtr GetAny() const; HeteroGraphPtr GetAny() const;
......
...@@ -395,6 +395,121 @@ def test_to_simple(): ...@@ -395,6 +395,121 @@ def test_to_simple():
for i, e in enumerate(uv): for i, e in enumerate(uv):
assert eid_map[i] == suv.index(e) assert eid_map[i] == suv.index(e)
# Exercises dgl.to_block on a single-relation graph and on a heterograph, both
# with and without an explicit set of right-hand-side (RHS) nodes.
# Skipped when the default context is GPU.
@unittest.skipIf(F._default_context_str == 'gpu', reason="GPU compaction not implemented")
def test_to_block():
# check(): validates one edge type `etype` of block `bg` against the original graph `g`.
# `ntype` is the name of the destination node type; `rhs_nodes` is the requested RHS
# node tensor for that type, or None when no RHS nodes were specified.
def check(g, bg, ntype, etype, rhs_nodes):
# When RHS nodes were given, the '<ntype>_r' nodes must carry exactly those IDs.
if rhs_nodes is not None:
assert F.array_equal(bg.nodes[ntype + '_r'].data[dgl.NID], rhs_nodes)
# The first n_rhs_nodes entries of the '<ntype>_l' node IDs must equal the
# '<ntype>_r' node IDs.
n_rhs_nodes = bg.number_of_nodes(ntype + '_r')
assert F.array_equal(
bg.nodes[ntype + '_l'].data[dgl.NID][:n_rhs_nodes],
bg.nodes[ntype + '_r'].data[dgl.NID])
# Restrict both graphs to the edge type under test.
g = g[etype]
bg = bg[etype]
induced_src = bg.srcdata[dgl.NID]
induced_dst = bg.dstdata[dgl.NID]
induced_eid = bg.edata[dgl.EID]
bg_src, bg_dst = bg.all_edges(order='eid')
src_ans, dst_ans = g.all_edges(order='eid')
# Map the block's edge endpoints back through the stored node IDs, and pick the
# original graph's endpoints through the stored edge IDs; the two must agree.
induced_src_bg = F.gather_row(induced_src, bg_src)
induced_dst_bg = F.gather_row(induced_dst, bg_dst)
induced_src_ans = F.gather_row(src_ans, induced_eid)
induced_dst_ans = F.gather_row(dst_ans, induced_eid)
assert F.array_equal(induced_src_bg, induced_src_ans)
assert F.array_equal(induced_dst_bg, induced_dst_ans)
# checkall(): runs check() for every edge type of `g`, using the destination node
# type of each canonical edge type and the matching entry of the `rhs_nodes` dict
# (or None when that node type has no entry).
def checkall(g, bg, rhs_nodes):
for etype in g.etypes:
ntype = g.to_canonical_etype(etype)[2]
if rhs_nodes is not None and ntype in rhs_nodes:
check(g, bg, ntype, etype, rhs_nodes[ntype])
else:
check(g, bg, ntype, etype, None)
# Heterograph fixture with three relations: A->A, A->B and B->A.
g = dgl.heterograph({
('A', 'AA', 'A'): [(0, 1), (2, 3), (1, 2), (3, 4)],
('A', 'AB', 'B'): [(0, 1), (1, 3), (3, 5), (1, 6)],
('B', 'BA', 'A'): [(2, 3), (3, 2)]})
# Single relation 'AA': no RHS nodes, then two explicit RHS node lists.
g_a = g['AA']
bg = dgl.to_block(g_a)
check(g_a, bg, 'A', 'AA', None)
rhs_nodes = F.tensor([3, 4], dtype=F.int64)
bg = dgl.to_block(g_a, rhs_nodes)
check(g_a, bg, 'A', 'AA', rhs_nodes)
rhs_nodes = F.tensor([4, 3, 2, 1], dtype=F.int64)
bg = dgl.to_block(g_a, rhs_nodes)
check(g_a, bg, 'A', 'AA', rhs_nodes)
# Relation 'AB' only: the four distinct B destinations (1, 3, 5, 6) are expected
# on both sides, and there are no A nodes on the right-hand side.
g_ab = g['AB']
bg = dgl.to_block(g_ab)
assert bg.number_of_nodes('B_l') == 4
assert F.array_equal(bg.nodes['B_l'].data[dgl.NID], bg.nodes['B_r'].data[dgl.NID])
assert bg.number_of_nodes('A_r') == 0
checkall(g_ab, bg, None)
# Full heterograph with RHS nodes given only for type 'B'.
rhs_nodes = {'B': F.tensor([5, 6], dtype=F.int64)}
bg = dgl.to_block(g, rhs_nodes)
assert bg.number_of_nodes('B_l') == 2
assert F.array_equal(bg.nodes['B_l'].data[dgl.NID], bg.nodes['B_r'].data[dgl.NID])
assert bg.number_of_nodes('A_r') == 0
checkall(g, bg, rhs_nodes)
# Full heterograph with RHS nodes given for both node types.
rhs_nodes = {'A': F.tensor([3, 4], dtype=F.int64), 'B': F.tensor([5, 6], dtype=F.int64)}
bg = dgl.to_block(g, rhs_nodes)
checkall(g, bg, rhs_nodes)
# Same, with longer RHS lists and rhs_nodes passed by keyword.
rhs_nodes = {'A': F.tensor([4, 3, 2, 1], dtype=F.int64), 'B': F.tensor([3, 5, 6, 1], dtype=F.int64)}
bg = dgl.to_block(g, rhs_nodes=rhs_nodes)
checkall(g, bg, rhs_nodes)
# Exercises dgl.remove_edges on homogeneous graphs built with each restricted
# sparse format, and on a heterograph with per-edge-type removals.
# Skipped when the default context is GPU.
@unittest.skipIf(F._default_context_str == 'gpu', reason="GPU not implemented")
def test_remove_edges():
# check(): compares the graph after removal (`g1`) with the original graph (`g`)
# for edge type `etype` (None for a graph with a single edge type).
# `edges_removed` is the list of edge IDs that were requested for removal.
def check(g1, etype, g, edges_removed):
src, dst, eid = g.edges(etype=etype, form='all')
src1, dst1 = g1.edges(etype=etype, order='eid')
# The dgl.EID edge feature of g1 is compared against g's edge IDs below.
if etype is not None:
eid1 = g1.edges[etype].data[dgl.EID]
else:
eid1 = g1.edata[dgl.EID]
src1 = F.asnumpy(src1)
dst1 = F.asnumpy(dst1)
eid1 = F.asnumpy(eid1)
src = F.asnumpy(src)
dst = F.asnumpy(dst)
eid = F.asnumpy(eid)
# Every remaining (src, dst, stored eid) triple must exist in the original graph.
sde_set = set(zip(src, dst, eid))
for s, d, e in zip(src1, dst1, eid1):
assert (s, d, e) in sde_set
# None of the removed edge IDs may appear among the remaining stored IDs.
assert not np.isin(edges_removed, eid1).any()
# Homogeneous graphs: each restricted format, with single, duplicated and
# unordered removal lists; built once from an edge list and once from a SciPy
# sparse matrix.
for fmt in ['coo', 'csr', 'csc']:
for edges_to_remove in [[2], [2, 2], [3, 2], [1, 3, 1, 2]]:
g = dgl.graph([(0, 1), (2, 3), (1, 2), (3, 4)], restrict_format=fmt)
g1 = dgl.remove_edges(g, F.tensor(edges_to_remove))
check(g1, None, g, edges_to_remove)
g = dgl.graph(
spsp.csr_matrix(([1, 1, 1, 1], ([0, 2, 1, 3], [1, 3, 2, 4])), shape=(5, 5)),
restrict_format=fmt)
g1 = dgl.remove_edges(g, F.tensor(edges_to_remove))
check(g1, None, g, edges_to_remove)
# Heterograph: remove one edge from each of the three edge types.
g = dgl.heterograph({
('A', 'AA', 'A'): [(0, 1), (2, 3), (1, 2), (3, 4)],
('A', 'AB', 'B'): [(0, 1), (1, 3), (3, 5), (1, 6)],
('B', 'BA', 'A'): [(2, 3), (3, 2)]})
g2 = dgl.remove_edges(g, {'AA': F.tensor([2]), 'AB': F.tensor([3]), 'BA': F.tensor([1])})
check(g2, 'AA', g, [2])
check(g2, 'AB', g, [3])
check(g2, 'BA', g, [1])
if __name__ == '__main__': if __name__ == '__main__':
test_line_graph() test_line_graph()
...@@ -413,3 +528,5 @@ if __name__ == '__main__': ...@@ -413,3 +528,5 @@ if __name__ == '__main__':
test_to_simple() test_to_simple()
test_in_subgraph() test_in_subgraph()
test_out_subgraph() test_out_subgraph()
test_to_block()
test_remove_edges()
...@@ -18,7 +18,7 @@ TEST(Serialize, UnitGraph_COO) { ...@@ -18,7 +18,7 @@ TEST(Serialize, UnitGraph_COO) {
auto src = VecToIdArray<int64_t>({1, 2, 5, 3}); auto src = VecToIdArray<int64_t>({1, 2, 5, 3});
auto dst = VecToIdArray<int64_t>({1, 6, 2, 6}); auto dst = VecToIdArray<int64_t>({1, 6, 2, 6});
auto mg = std::dynamic_pointer_cast<UnitGraph>( auto mg = std::dynamic_pointer_cast<UnitGraph>(
dgl::UnitGraph::CreateFromCOO(2, 9, 8, src, dst, dgl::SparseFormat::COO)); dgl::UnitGraph::CreateFromCOO(2, 9, 8, src, dst, dgl::SparseFormat::kCOO));
std::string blob; std::string blob;
dmlc::MemoryStringStream ifs(&blob); dmlc::MemoryStringStream ifs(&blob);
...@@ -40,7 +40,7 @@ TEST(Serialize, UnitGraph_CSR) { ...@@ -40,7 +40,7 @@ TEST(Serialize, UnitGraph_CSR) {
auto src = VecToIdArray<int64_t>({1, 2, 5, 3}); auto src = VecToIdArray<int64_t>({1, 2, 5, 3});
auto dst = VecToIdArray<int64_t>({1, 6, 2, 6}); auto dst = VecToIdArray<int64_t>({1, 6, 2, 6});
auto mg = std::dynamic_pointer_cast<UnitGraph>( auto mg = std::dynamic_pointer_cast<UnitGraph>(
dgl::UnitGraph::CreateFromCOO(2, 9, 8, src, dst, dgl::SparseFormat::CSR)); dgl::UnitGraph::CreateFromCOO(2, 9, 8, src, dst, dgl::SparseFormat::kCSR));
std::string blob; std::string blob;
dmlc::MemoryStringStream ifs(&blob); dmlc::MemoryStringStream ifs(&blob);
......
...@@ -402,6 +402,15 @@ def test_sage_conv(): ...@@ -402,6 +402,15 @@ def test_sage_conv():
h = sage(g, feat) h = sage(g, feat)
assert h.shape[-1] == 10 assert h.shape[-1] == 10
g = dgl.bipartite(sp.sparse.random(100, 200, density=0.1))
dst_dim = 5 if aggre_type != 'gcn' else 10
sage = nn.SAGEConv((10, dst_dim), 2, aggre_type)
feat = (F.randn((100, 10)), F.randn((200, dst_dim)))
sage = sage.to(ctx)
h = sage(g, feat)
assert h.shape[-1] == 2
assert h.shape[0] == 200
def test_sgc_conv(): def test_sgc_conv():
ctx = F.ctx() ctx = F.ctx()
g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True) g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment