Unverified Commit a9520f71 authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

[Model][Sampler] GraphSAGE model, bipartite graph conversion & remove edges API (#1297)

* remove edge and to bipartite and graphsage with sampling

* fixes

* fixes

* fixes

* reenable multigpu training

* fixes

* compatibility in DGLGraph

* rename to compact_as_bipartite

* bugfix

* lint

* add offline inference

* skip GPU tests

* fix

* addresses comments

* fix

* fix

* fix

* more tests

* more docs and unit tests

* workaround for empty slice on empty data
parent ce6e19f2
......@@ -12,7 +12,7 @@ HeteroGraphPtr DisjointUnionHeteroGraph(
GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& component_graphs) {
CHECK_GT(component_graphs.size(), 0) << "Input graph list is empty";
std::vector<HeteroGraphPtr> rel_graphs(meta_graph->NumEdges());
std::vector<int64_t> num_nodes_per_type(meta_graph->NumVertices());
std::vector<int64_t> num_nodes_per_type(meta_graph->NumVertices(), 0);
// Loop over all canonical etypes
for (dgl_type_t etype = 0; etype < meta_graph->NumEdges(); ++etype) {
......
......@@ -83,6 +83,7 @@ class UnitGraph::COO : public BaseHeteroGraph {
: BaseHeteroGraph(metagraph), adj_(coo) {
// Data index should not be inherited. Edges in COO format are always
// assigned ids from 0 to num_edges - 1.
CHECK(!COOHasData(coo)) << "[BUG] COO should not contain data.";
adj_.data = aten::NullArray();
}
......@@ -344,7 +345,7 @@ class UnitGraph::COO : public BaseHeteroGraph {
SparseFormat SelectFormat(dgl_type_t etype, SparseFormat preferred_format) const override {
LOG(FATAL) << "Not enabled for COO graph";
return SparseFormat::ANY;
return SparseFormat::kAny;
}
HeteroSubgraph VertexSubgraph(const std::vector<IdArray>& vids) const override {
......@@ -443,6 +444,7 @@ class UnitGraph::CSR : public BaseHeteroGraph {
CHECK(aten::IsValidIdArray(edge_ids));
CHECK_EQ(indices->shape[0], edge_ids->shape[0])
<< "indices and edge id arrays should have the same length";
adj_ = aten::CSRMatrix{num_src, num_dst, indptr, indices, edge_ids};
}
......@@ -724,7 +726,7 @@ class UnitGraph::CSR : public BaseHeteroGraph {
SparseFormat SelectFormat(dgl_type_t etype, SparseFormat preferred_format) const override {
LOG(FATAL) << "Not enabled for CSR graph";
return SparseFormat::ANY;
return SparseFormat::kAny;
}
HeteroSubgraph VertexSubgraph(const std::vector<IdArray>& vids) const override {
......@@ -801,11 +803,11 @@ bool UnitGraph::IsMultigraph() const {
}
uint64_t UnitGraph::NumVertices(dgl_type_t vtype) const {
const SparseFormat fmt = SelectFormat(SparseFormat::ANY);
const SparseFormat fmt = SelectFormat(SparseFormat::kAny);
const auto ptr = GetFormat(fmt);
// TODO(BarclayII): we have a lot of special handling for CSC.
// Need to have a UnitGraph::CSC backend instead.
if (fmt == SparseFormat::CSC)
if (fmt == SparseFormat::kCSC)
vtype = (vtype == SrcType()) ? DstType() : SrcType();
return ptr->NumVertices(vtype);
}
......@@ -815,9 +817,9 @@ uint64_t UnitGraph::NumEdges(dgl_type_t etype) const {
}
bool UnitGraph::HasVertex(dgl_type_t vtype, dgl_id_t vid) const {
const SparseFormat fmt = SelectFormat(SparseFormat::ANY);
const SparseFormat fmt = SelectFormat(SparseFormat::kAny);
const auto ptr = GetFormat(fmt);
if (fmt == SparseFormat::CSC)
if (fmt == SparseFormat::kCSC)
vtype = (vtype == SrcType()) ? DstType() : SrcType();
return ptr->HasVertex(vtype, vid);
}
......@@ -828,9 +830,9 @@ BoolArray UnitGraph::HasVertices(dgl_type_t vtype, IdArray vids) const {
}
bool UnitGraph::HasEdgeBetween(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const {
const SparseFormat fmt = SelectFormat(SparseFormat::ANY);
const SparseFormat fmt = SelectFormat(SparseFormat::kAny);
const auto ptr = GetFormat(fmt);
if (fmt == SparseFormat::CSC)
if (fmt == SparseFormat::kCSC)
return ptr->HasEdgeBetween(etype, dst, src);
else
return ptr->HasEdgeBetween(etype, src, dst);
......@@ -838,42 +840,42 @@ bool UnitGraph::HasEdgeBetween(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) con
BoolArray UnitGraph::HasEdgesBetween(
dgl_type_t etype, IdArray src, IdArray dst) const {
const SparseFormat fmt = SelectFormat(SparseFormat::ANY);
const SparseFormat fmt = SelectFormat(SparseFormat::kAny);
const auto ptr = GetFormat(fmt);
if (fmt == SparseFormat::CSC)
if (fmt == SparseFormat::kCSC)
return ptr->HasEdgesBetween(etype, dst, src);
else
return ptr->HasEdgesBetween(etype, src, dst);
}
IdArray UnitGraph::Predecessors(dgl_type_t etype, dgl_id_t dst) const {
const SparseFormat fmt = SelectFormat(SparseFormat::CSC);
const SparseFormat fmt = SelectFormat(SparseFormat::kCSC);
const auto ptr = GetFormat(fmt);
if (fmt == SparseFormat::CSC)
if (fmt == SparseFormat::kCSC)
return ptr->Successors(etype, dst);
else
return ptr->Predecessors(etype, dst);
}
IdArray UnitGraph::Successors(dgl_type_t etype, dgl_id_t src) const {
const SparseFormat fmt = SelectFormat(SparseFormat::CSR);
const SparseFormat fmt = SelectFormat(SparseFormat::kCSR);
const auto ptr = GetFormat(fmt);
return ptr->Successors(etype, src);
}
IdArray UnitGraph::EdgeId(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const {
const SparseFormat fmt = SelectFormat(SparseFormat::ANY);
const SparseFormat fmt = SelectFormat(SparseFormat::kAny);
const auto ptr = GetFormat(fmt);
if (fmt == SparseFormat::CSC)
if (fmt == SparseFormat::kCSC)
return ptr->EdgeId(etype, dst, src);
else
return ptr->EdgeId(etype, src, dst);
}
EdgeArray UnitGraph::EdgeIds(dgl_type_t etype, IdArray src, IdArray dst) const {
const SparseFormat fmt = SelectFormat(SparseFormat::ANY);
const SparseFormat fmt = SelectFormat(SparseFormat::kAny);
const auto ptr = GetFormat(fmt);
if (fmt == SparseFormat::CSC) {
if (fmt == SparseFormat::kCSC) {
EdgeArray edges = ptr->EdgeIds(etype, dst, src);
return EdgeArray{edges.dst, edges.src, edges.id};
} else {
......@@ -882,21 +884,21 @@ EdgeArray UnitGraph::EdgeIds(dgl_type_t etype, IdArray src, IdArray dst) const {
}
std::pair<dgl_id_t, dgl_id_t> UnitGraph::FindEdge(dgl_type_t etype, dgl_id_t eid) const {
const SparseFormat fmt = SelectFormat(SparseFormat::COO);
const SparseFormat fmt = SelectFormat(SparseFormat::kCOO);
const auto ptr = GetFormat(fmt);
return ptr->FindEdge(etype, eid);
}
EdgeArray UnitGraph::FindEdges(dgl_type_t etype, IdArray eids) const {
const SparseFormat fmt = SelectFormat(SparseFormat::COO);
const SparseFormat fmt = SelectFormat(SparseFormat::kCOO);
const auto ptr = GetFormat(fmt);
return ptr->FindEdges(etype, eids);
}
EdgeArray UnitGraph::InEdges(dgl_type_t etype, dgl_id_t vid) const {
const SparseFormat fmt = SelectFormat(SparseFormat::CSC);
const SparseFormat fmt = SelectFormat(SparseFormat::kCSC);
const auto ptr = GetFormat(fmt);
if (fmt == SparseFormat::CSC) {
if (fmt == SparseFormat::kCSC) {
const EdgeArray& ret = ptr->OutEdges(etype, vid);
return {ret.dst, ret.src, ret.id};
} else {
......@@ -905,9 +907,9 @@ EdgeArray UnitGraph::InEdges(dgl_type_t etype, dgl_id_t vid) const {
}
EdgeArray UnitGraph::InEdges(dgl_type_t etype, IdArray vids) const {
const SparseFormat fmt = SelectFormat(SparseFormat::CSC);
const SparseFormat fmt = SelectFormat(SparseFormat::kCSC);
const auto ptr = GetFormat(fmt);
if (fmt == SparseFormat::CSC) {
if (fmt == SparseFormat::kCSC) {
const EdgeArray& ret = ptr->OutEdges(etype, vids);
return {ret.dst, ret.src, ret.id};
} else {
......@@ -916,13 +918,13 @@ EdgeArray UnitGraph::InEdges(dgl_type_t etype, IdArray vids) const {
}
EdgeArray UnitGraph::OutEdges(dgl_type_t etype, dgl_id_t vid) const {
const SparseFormat fmt = SelectFormat(SparseFormat::CSR);
const SparseFormat fmt = SelectFormat(SparseFormat::kCSR);
const auto ptr = GetFormat(fmt);
return ptr->OutEdges(etype, vid);
}
EdgeArray UnitGraph::OutEdges(dgl_type_t etype, IdArray vids) const {
const SparseFormat fmt = SelectFormat(SparseFormat::CSR);
const SparseFormat fmt = SelectFormat(SparseFormat::kCSR);
const auto ptr = GetFormat(fmt);
return ptr->OutEdges(etype, vids);
}
......@@ -930,79 +932,79 @@ EdgeArray UnitGraph::OutEdges(dgl_type_t etype, IdArray vids) const {
EdgeArray UnitGraph::Edges(dgl_type_t etype, const std::string &order) const {
SparseFormat fmt;
if (order == std::string("eid")) {
fmt = SelectFormat(SparseFormat::COO);
fmt = SelectFormat(SparseFormat::kCOO);
} else if (order.empty()) {
// arbitrary order
fmt = SelectFormat(SparseFormat::ANY);
fmt = SelectFormat(SparseFormat::kAny);
} else if (order == std::string("srcdst")) {
fmt = SelectFormat(SparseFormat::CSR);
fmt = SelectFormat(SparseFormat::kCSR);
} else {
LOG(FATAL) << "Unsupported order request: " << order;
return {};
}
const auto& edges = GetFormat(fmt)->Edges(etype, order);
if (fmt == SparseFormat::CSC)
if (fmt == SparseFormat::kCSC)
return EdgeArray{edges.dst, edges.src, edges.id};
else
return edges;
}
uint64_t UnitGraph::InDegree(dgl_type_t etype, dgl_id_t vid) const {
SparseFormat fmt = SelectFormat(SparseFormat::CSC);
SparseFormat fmt = SelectFormat(SparseFormat::kCSC);
const auto ptr = GetFormat(fmt);
if (fmt == SparseFormat::CSC)
if (fmt == SparseFormat::kCSC)
return ptr->OutDegree(etype, vid);
else
return ptr->InDegree(etype, vid);
}
DegreeArray UnitGraph::InDegrees(dgl_type_t etype, IdArray vids) const {
SparseFormat fmt = SelectFormat(SparseFormat::CSC);
SparseFormat fmt = SelectFormat(SparseFormat::kCSC);
const auto ptr = GetFormat(fmt);
if (fmt == SparseFormat::CSC)
if (fmt == SparseFormat::kCSC)
return ptr->OutDegrees(etype, vids);
else
return ptr->InDegrees(etype, vids);
}
uint64_t UnitGraph::OutDegree(dgl_type_t etype, dgl_id_t vid) const {
SparseFormat fmt = SelectFormat(SparseFormat::CSR);
SparseFormat fmt = SelectFormat(SparseFormat::kCSR);
const auto ptr = GetFormat(fmt);
return ptr->OutDegree(etype, vid);
}
DegreeArray UnitGraph::OutDegrees(dgl_type_t etype, IdArray vids) const {
SparseFormat fmt = SelectFormat(SparseFormat::CSR);
SparseFormat fmt = SelectFormat(SparseFormat::kCSR);
const auto ptr = GetFormat(fmt);
return ptr->OutDegrees(etype, vids);
}
DGLIdIters UnitGraph::SuccVec(dgl_type_t etype, dgl_id_t vid) const {
SparseFormat fmt = SelectFormat(SparseFormat::CSR);
SparseFormat fmt = SelectFormat(SparseFormat::kCSR);
const auto ptr = GetFormat(fmt);
return ptr->SuccVec(etype, vid);
}
DGLIdIters UnitGraph::OutEdgeVec(dgl_type_t etype, dgl_id_t vid) const {
SparseFormat fmt = SelectFormat(SparseFormat::CSR);
SparseFormat fmt = SelectFormat(SparseFormat::kCSR);
const auto ptr = GetFormat(fmt);
return ptr->OutEdgeVec(etype, vid);
}
DGLIdIters UnitGraph::PredVec(dgl_type_t etype, dgl_id_t vid) const {
SparseFormat fmt = SelectFormat(SparseFormat::CSC);
SparseFormat fmt = SelectFormat(SparseFormat::kCSC);
const auto ptr = GetFormat(fmt);
if (fmt == SparseFormat::CSC)
if (fmt == SparseFormat::kCSC)
return ptr->SuccVec(etype, vid);
else
return ptr->PredVec(etype, vid);
}
DGLIdIters UnitGraph::InEdgeVec(dgl_type_t etype, dgl_id_t vid) const {
SparseFormat fmt = SelectFormat(SparseFormat::CSC);
SparseFormat fmt = SelectFormat(SparseFormat::kCSC);
const auto ptr = GetFormat(fmt);
if (fmt == SparseFormat::CSC)
if (fmt == SparseFormat::kCSC)
return ptr->OutEdgeVec(etype, vid);
else
return ptr->InEdgeVec(etype, vid);
......@@ -1030,7 +1032,7 @@ std::vector<IdArray> UnitGraph::GetAdj(
HeteroSubgraph UnitGraph::VertexSubgraph(const std::vector<IdArray>& vids) const {
// We prefer to generate a subgraph from out-csr.
SparseFormat fmt = SelectFormat(SparseFormat::CSR);
SparseFormat fmt = SelectFormat(SparseFormat::kCSR);
HeteroSubgraph sg = GetFormat(fmt)->VertexSubgraph(vids);
CSRPtr subcsr = std::dynamic_pointer_cast<CSR>(sg.graph);
HeteroSubgraph ret;
......@@ -1042,7 +1044,7 @@ HeteroSubgraph UnitGraph::VertexSubgraph(const std::vector<IdArray>& vids) const
HeteroSubgraph UnitGraph::EdgeSubgraph(
const std::vector<IdArray>& eids, bool preserve_nodes) const {
SparseFormat fmt = SelectFormat(SparseFormat::COO);
SparseFormat fmt = SelectFormat(SparseFormat::kCOO);
auto sg = GetFormat(fmt)->EdgeSubgraph(eids, preserve_nodes);
COOPtr subcoo = std::dynamic_pointer_cast<COO>(sg.graph);
HeteroSubgraph ret;
......@@ -1100,6 +1102,28 @@ HeteroGraphPtr UnitGraph::CreateFromCSR(
return HeteroGraphPtr(new UnitGraph(mg, nullptr, csr, nullptr, restrict_format));
}
HeteroGraphPtr UnitGraph::CreateFromCSC(
int64_t num_vtypes, int64_t num_src, int64_t num_dst,
IdArray indptr, IdArray indices, IdArray edge_ids, SparseFormat restrict_format) {
CHECK(num_vtypes == 1 || num_vtypes == 2);
if (num_vtypes == 1)
CHECK_EQ(num_src, num_dst);
auto mg = CreateUnitGraphMetaGraph(num_vtypes);
CSRPtr csc(new CSR(mg, num_src, num_dst, indptr, indices, edge_ids));
return HeteroGraphPtr(new UnitGraph(mg, csc, nullptr, nullptr, restrict_format));
}
HeteroGraphPtr UnitGraph::CreateFromCSC(
int64_t num_vtypes, const aten::CSRMatrix& mat,
SparseFormat restrict_format) {
CHECK(num_vtypes == 1 || num_vtypes == 2);
if (num_vtypes == 1)
CHECK_EQ(mat.num_rows, mat.num_cols);
auto mg = CreateUnitGraphMetaGraph(num_vtypes);
CSRPtr csc(new CSR(mg, mat));
return HeteroGraphPtr(new UnitGraph(mg, csc, nullptr, nullptr, restrict_format));
}
HeteroGraphPtr UnitGraph::AsNumBits(HeteroGraphPtr g, uint8_t bits) {
if (g->NumBits() == bits) {
return g;
......@@ -1143,9 +1167,9 @@ UnitGraph::UnitGraph(GraphPtr metagraph, CSRPtr in_csr, CSRPtr out_csr, COOPtr c
// If the graph is hypersparse and in COO format, switch the restricted format to COO.
// If the graph is given as CSR, the indptr array is already materialized so we don't
// care about restricting conversion anyway (even if it is hypersparse).
if (restrict_format == SparseFormat::ANY) {
if (restrict_format == SparseFormat::kAny) {
if (coo && coo->IsHypersparse())
restrict_format_ = SparseFormat::COO;
restrict_format_ = SparseFormat::kCOO;
}
CHECK(GetAny()) << "At least one graph structure should exist.";
......@@ -1221,31 +1245,31 @@ HeteroGraphPtr UnitGraph::GetAny() const {
HeteroGraphPtr UnitGraph::GetFormat(SparseFormat format) const {
switch (format) {
case SparseFormat::CSR:
return GetOutCSR();
case SparseFormat::CSC:
return GetInCSR();
case SparseFormat::COO:
return GetCOO();
case SparseFormat::ANY:
return GetAny();
default:
LOG(FATAL) << "unsupported format code";
return nullptr;
case SparseFormat::kCSR:
return GetOutCSR();
case SparseFormat::kCSC:
return GetInCSR();
case SparseFormat::kCOO:
return GetCOO();
case SparseFormat::kAny:
return GetAny();
default:
LOG(FATAL) << "unsupported format code";
return nullptr;
}
}
SparseFormat UnitGraph::SelectFormat(SparseFormat preferred_format) const {
if (restrict_format_ != SparseFormat::ANY)
if (restrict_format_ != SparseFormat::kAny)
return restrict_format_;
else if (preferred_format != SparseFormat::ANY)
else if (preferred_format != SparseFormat::kAny)
return preferred_format;
else if (in_csr_)
return SparseFormat::CSC;
return SparseFormat::kCSC;
else if (out_csr_)
return SparseFormat::CSR;
return SparseFormat::kCSR;
else
return SparseFormat::COO;
return SparseFormat::kCOO;
}
constexpr uint64_t kDGLSerialize_UnitGraphMagic = 0xDD2E60F0F6B4A127;
......@@ -1260,13 +1284,13 @@ bool UnitGraph::Load(dmlc::Stream* fs) {
restrict_format_ = static_cast<SparseFormat>(format_code);
switch (restrict_format_) {
case SparseFormat::COO:
case SparseFormat::kCOO:
fs->Read(&coo_);
break;
case SparseFormat::CSR:
case SparseFormat::kCSR:
fs->Read(&out_csr_);
break;
case SparseFormat::CSC:
case SparseFormat::kCSC:
fs->Read(&in_csr_);
break;
default:
......@@ -1284,16 +1308,16 @@ void UnitGraph::Save(dmlc::Stream* fs) const {
fs->Write(kDGLSerialize_UnitGraphMagic);
// Didn't write UnitGraph::meta_graph_, since it's included in the underlying
// sparse matrix
auto avail_fmt = SelectFormat(SparseFormat::ANY);
auto avail_fmt = SelectFormat(SparseFormat::kAny);
fs->Write(static_cast<int64_t>(avail_fmt));
switch (avail_fmt) {
case SparseFormat::COO:
case SparseFormat::kCOO:
fs->Write(GetCOO());
break;
case SparseFormat::CSR:
case SparseFormat::kCSR:
fs->Write(GetOutCSR());
break;
case SparseFormat::CSC:
case SparseFormat::kCSC:
fs->Write(GetInCSR());
break;
default:
......
......@@ -166,21 +166,31 @@ class UnitGraph : public BaseHeteroGraph {
/*! \brief Create a graph from COO arrays */
static HeteroGraphPtr CreateFromCOO(
int64_t num_vtypes, int64_t num_src, int64_t num_dst,
IdArray row, IdArray col, SparseFormat restrict_format = SparseFormat::ANY);
IdArray row, IdArray col, SparseFormat restrict_format = SparseFormat::kAny);
static HeteroGraphPtr CreateFromCOO(
int64_t num_vtypes, const aten::COOMatrix& mat,
SparseFormat restrict_format = SparseFormat::ANY);
SparseFormat restrict_format = SparseFormat::kAny);
/*! \brief Create a graph from (out) CSR arrays */
static HeteroGraphPtr CreateFromCSR(
int64_t num_vtypes, int64_t num_src, int64_t num_dst,
IdArray indptr, IdArray indices, IdArray edge_ids,
SparseFormat restrict_format = SparseFormat::ANY);
SparseFormat restrict_format = SparseFormat::kAny);
static HeteroGraphPtr CreateFromCSR(
int64_t num_vtypes, const aten::CSRMatrix& mat,
SparseFormat restrict_format = SparseFormat::ANY);
SparseFormat restrict_format = SparseFormat::kAny);
/*! \brief Create a graph from (in) CSC arrays */
static HeteroGraphPtr CreateFromCSC(
int64_t num_vtypes, int64_t num_src, int64_t num_dst,
IdArray indptr, IdArray indices, IdArray edge_ids,
SparseFormat restrict_format = SparseFormat::kAny);
static HeteroGraphPtr CreateFromCSC(
int64_t num_vtypes, const aten::CSRMatrix& mat,
SparseFormat restrict_format = SparseFormat::kAny);
/*! \brief Convert the graph to use the given number of bits for storage */
static HeteroGraphPtr AsNumBits(HeteroGraphPtr g, uint8_t bits);
......@@ -231,7 +241,7 @@ class UnitGraph : public BaseHeteroGraph {
* \param coo coo
*/
UnitGraph(GraphPtr metagraph, CSRPtr in_csr, CSRPtr out_csr, COOPtr coo,
SparseFormat restrict_format = SparseFormat::ANY);
SparseFormat restrict_format = SparseFormat::kAny);
/*! \return Return any existing format. */
HeteroGraphPtr GetAny() const;
......
......@@ -395,6 +395,121 @@ def test_to_simple():
for i, e in enumerate(uv):
assert eid_map[i] == suv.index(e)
@unittest.skipIf(F._default_context_str == 'gpu', reason="GPU compaction not implemented")
def test_to_block():
def check(g, bg, ntype, etype, rhs_nodes):
if rhs_nodes is not None:
assert F.array_equal(bg.nodes[ntype + '_r'].data[dgl.NID], rhs_nodes)
n_rhs_nodes = bg.number_of_nodes(ntype + '_r')
assert F.array_equal(
bg.nodes[ntype + '_l'].data[dgl.NID][:n_rhs_nodes],
bg.nodes[ntype + '_r'].data[dgl.NID])
g = g[etype]
bg = bg[etype]
induced_src = bg.srcdata[dgl.NID]
induced_dst = bg.dstdata[dgl.NID]
induced_eid = bg.edata[dgl.EID]
bg_src, bg_dst = bg.all_edges(order='eid')
src_ans, dst_ans = g.all_edges(order='eid')
induced_src_bg = F.gather_row(induced_src, bg_src)
induced_dst_bg = F.gather_row(induced_dst, bg_dst)
induced_src_ans = F.gather_row(src_ans, induced_eid)
induced_dst_ans = F.gather_row(dst_ans, induced_eid)
assert F.array_equal(induced_src_bg, induced_src_ans)
assert F.array_equal(induced_dst_bg, induced_dst_ans)
def checkall(g, bg, rhs_nodes):
for etype in g.etypes:
ntype = g.to_canonical_etype(etype)[2]
if rhs_nodes is not None and ntype in rhs_nodes:
check(g, bg, ntype, etype, rhs_nodes[ntype])
else:
check(g, bg, ntype, etype, None)
g = dgl.heterograph({
('A', 'AA', 'A'): [(0, 1), (2, 3), (1, 2), (3, 4)],
('A', 'AB', 'B'): [(0, 1), (1, 3), (3, 5), (1, 6)],
('B', 'BA', 'A'): [(2, 3), (3, 2)]})
g_a = g['AA']
bg = dgl.to_block(g_a)
check(g_a, bg, 'A', 'AA', None)
rhs_nodes = F.tensor([3, 4], dtype=F.int64)
bg = dgl.to_block(g_a, rhs_nodes)
check(g_a, bg, 'A', 'AA', rhs_nodes)
rhs_nodes = F.tensor([4, 3, 2, 1], dtype=F.int64)
bg = dgl.to_block(g_a, rhs_nodes)
check(g_a, bg, 'A', 'AA', rhs_nodes)
g_ab = g['AB']
bg = dgl.to_block(g_ab)
assert bg.number_of_nodes('B_l') == 4
assert F.array_equal(bg.nodes['B_l'].data[dgl.NID], bg.nodes['B_r'].data[dgl.NID])
assert bg.number_of_nodes('A_r') == 0
checkall(g_ab, bg, None)
rhs_nodes = {'B': F.tensor([5, 6], dtype=F.int64)}
bg = dgl.to_block(g, rhs_nodes)
assert bg.number_of_nodes('B_l') == 2
assert F.array_equal(bg.nodes['B_l'].data[dgl.NID], bg.nodes['B_r'].data[dgl.NID])
assert bg.number_of_nodes('A_r') == 0
checkall(g, bg, rhs_nodes)
rhs_nodes = {'A': F.tensor([3, 4], dtype=F.int64), 'B': F.tensor([5, 6], dtype=F.int64)}
bg = dgl.to_block(g, rhs_nodes)
checkall(g, bg, rhs_nodes)
rhs_nodes = {'A': F.tensor([4, 3, 2, 1], dtype=F.int64), 'B': F.tensor([3, 5, 6, 1], dtype=F.int64)}
bg = dgl.to_block(g, rhs_nodes=rhs_nodes)
checkall(g, bg, rhs_nodes)
@unittest.skipIf(F._default_context_str == 'gpu', reason="GPU not implemented")
def test_remove_edges():
def check(g1, etype, g, edges_removed):
src, dst, eid = g.edges(etype=etype, form='all')
src1, dst1 = g1.edges(etype=etype, order='eid')
if etype is not None:
eid1 = g1.edges[etype].data[dgl.EID]
else:
eid1 = g1.edata[dgl.EID]
src1 = F.asnumpy(src1)
dst1 = F.asnumpy(dst1)
eid1 = F.asnumpy(eid1)
src = F.asnumpy(src)
dst = F.asnumpy(dst)
eid = F.asnumpy(eid)
sde_set = set(zip(src, dst, eid))
for s, d, e in zip(src1, dst1, eid1):
assert (s, d, e) in sde_set
assert not np.isin(edges_removed, eid1).any()
for fmt in ['coo', 'csr', 'csc']:
for edges_to_remove in [[2], [2, 2], [3, 2], [1, 3, 1, 2]]:
g = dgl.graph([(0, 1), (2, 3), (1, 2), (3, 4)], restrict_format=fmt)
g1 = dgl.remove_edges(g, F.tensor(edges_to_remove))
check(g1, None, g, edges_to_remove)
g = dgl.graph(
spsp.csr_matrix(([1, 1, 1, 1], ([0, 2, 1, 3], [1, 3, 2, 4])), shape=(5, 5)),
restrict_format=fmt)
g1 = dgl.remove_edges(g, F.tensor(edges_to_remove))
check(g1, None, g, edges_to_remove)
g = dgl.heterograph({
('A', 'AA', 'A'): [(0, 1), (2, 3), (1, 2), (3, 4)],
('A', 'AB', 'B'): [(0, 1), (1, 3), (3, 5), (1, 6)],
('B', 'BA', 'A'): [(2, 3), (3, 2)]})
g2 = dgl.remove_edges(g, {'AA': F.tensor([2]), 'AB': F.tensor([3]), 'BA': F.tensor([1])})
check(g2, 'AA', g, [2])
check(g2, 'AB', g, [3])
check(g2, 'BA', g, [1])
if __name__ == '__main__':
test_line_graph()
......@@ -413,3 +528,5 @@ if __name__ == '__main__':
test_to_simple()
test_in_subgraph()
test_out_subgraph()
test_to_block()
test_remove_edges()
......@@ -18,7 +18,7 @@ TEST(Serialize, UnitGraph_COO) {
auto src = VecToIdArray<int64_t>({1, 2, 5, 3});
auto dst = VecToIdArray<int64_t>({1, 6, 2, 6});
auto mg = std::dynamic_pointer_cast<UnitGraph>(
dgl::UnitGraph::CreateFromCOO(2, 9, 8, src, dst, dgl::SparseFormat::COO));
dgl::UnitGraph::CreateFromCOO(2, 9, 8, src, dst, dgl::SparseFormat::kCOO));
std::string blob;
dmlc::MemoryStringStream ifs(&blob);
......@@ -40,7 +40,7 @@ TEST(Serialize, UnitGraph_CSR) {
auto src = VecToIdArray<int64_t>({1, 2, 5, 3});
auto dst = VecToIdArray<int64_t>({1, 6, 2, 6});
auto mg = std::dynamic_pointer_cast<UnitGraph>(
dgl::UnitGraph::CreateFromCOO(2, 9, 8, src, dst, dgl::SparseFormat::CSR));
dgl::UnitGraph::CreateFromCOO(2, 9, 8, src, dst, dgl::SparseFormat::kCSR));
std::string blob;
dmlc::MemoryStringStream ifs(&blob);
......
......@@ -402,6 +402,15 @@ def test_sage_conv():
h = sage(g, feat)
assert h.shape[-1] == 10
g = dgl.bipartite(sp.sparse.random(100, 200, density=0.1))
dst_dim = 5 if aggre_type != 'gcn' else 10
sage = nn.SAGEConv((10, dst_dim), 2, aggre_type)
feat = (F.randn((100, 10)), F.randn((200, dst_dim)))
sage = sage.to(ctx)
h = sage(g, feat)
assert h.shape[-1] == 2
assert h.shape[0] == 200
def test_sgc_conv():
ctx = F.ctx()
g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment