Unverified Commit feabcabd authored by xiang song(charlie.song)'s avatar xiang song(charlie.song) Committed by GitHub
Browse files

[Kernel]GraphOp::Reverse (#1730)



* Compile OK

* Fix compile

* Explore reverse to heterograph_indx and add python test code

* lint

* Fix some bug

* Fix bug

* upd

* Fix

* fix lint

* Fix

* Fix

* test

* triger

* Fix

* upd

* add TODO

* upd
Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-51-214.ec2.internal>
Co-authored-by: default avatarZihao Ye <expye@outlook.com>
Co-authored-by: default avatarJinjing Zhou <VoVAllen@users.noreply.github.com>
parent cadcc1c2
...@@ -991,6 +991,16 @@ class HeteroGraphIndex(ObjectBase): ...@@ -991,6 +991,16 @@ class HeteroGraphIndex(ObjectBase):
""" """
return _CAPI_DGLHeteroGetFormatGraph(self, restrict_format) return _CAPI_DGLHeteroGetFormatGraph(self, restrict_format)
def reverse(self):
"""Reverse the heterogeneous graph adjacency
The node types and edge types are not changed
Returns
-------
A new graph index.
"""
return _CAPI_DGLHeteroReverse(self)
@register_object('graph.HeteroSubgraph') @register_object('graph.HeteroSubgraph')
class HeteroSubgraphIndex(ObjectBase): class HeteroSubgraphIndex(ObjectBase):
......
...@@ -220,6 +220,10 @@ class HeteroGraph : public BaseHeteroGraph { ...@@ -220,6 +220,10 @@ class HeteroGraph : public BaseHeteroGraph {
/*! \brief Copy the data to another context */ /*! \brief Copy the data to another context */
static HeteroGraphPtr CopyTo(HeteroGraphPtr g, const DLContext& ctx); static HeteroGraphPtr CopyTo(HeteroGraphPtr g, const DLContext& ctx);
const std::vector<UnitGraphPtr>& relation_graphs() const {
return relation_graphs_;
}
private: private:
// To create empty class // To create empty class
friend class Serializer; friend class Serializer;
......
...@@ -596,4 +596,22 @@ DGL_REGISTER_GLOBAL("heterograph._CAPI_DGLFindSrcDstNtypes") ...@@ -596,4 +596,22 @@ DGL_REGISTER_GLOBAL("heterograph._CAPI_DGLFindSrcDstNtypes")
ret_list.push_back(dstlist); ret_list.push_back(dstlist);
*rv = ret_list; *rv = ret_list;
}); });
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroReverse")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
CHECK_GT(hg->NumEdgeTypes(), 0);
auto g = std::dynamic_pointer_cast<HeteroGraph>(hg.sptr());
std::vector<HeteroGraphPtr> rev_ugs;
const auto &ugs = g->relation_graphs();
rev_ugs.resize(ugs.size());
for (size_t i = 0; i < ugs.size(); ++i) {
const auto &rev_ug = ugs[i]->Reverse();
rev_ugs[i] = rev_ug;
}
// node types are not changed
const auto& num_nodes = g->NumVerticesPerType();
*rv = CreateHeteroGraph(hg->meta_graph(), rev_ugs, num_nodes);
});
} // namespace dgl } // namespace dgl
...@@ -77,6 +77,17 @@ class UnitGraph::COO : public BaseHeteroGraph { ...@@ -77,6 +77,17 @@ class UnitGraph::COO : public BaseHeteroGraph {
adj_.data = aten::NullArray(); adj_.data = aten::NullArray();
} }
COO() {
// set magic num_rows/num_cols to mark it as undefined
// adj_.num_rows == 0 and adj_.num_cols == 0 means empty UnitGraph which is supported
adj_.num_rows = -1;
adj_.num_cols = -1;
};
bool defined() const {
return (adj_.num_rows >= 0) && (adj_.num_cols >= 0);
}
inline dgl_type_t SrcType() const { inline dgl_type_t SrcType() const {
return 0; return 0;
} }
...@@ -415,8 +426,6 @@ class UnitGraph::COO : public BaseHeteroGraph { ...@@ -415,8 +426,6 @@ class UnitGraph::COO : public BaseHeteroGraph {
private: private:
friend class Serializer; friend class Serializer;
COO() {}
/*! \brief internal adjacency matrix. Data array is empty */ /*! \brief internal adjacency matrix. Data array is empty */
aten::COOMatrix adj_; aten::COOMatrix adj_;
}; };
...@@ -446,6 +455,17 @@ class UnitGraph::CSR : public BaseHeteroGraph { ...@@ -446,6 +455,17 @@ class UnitGraph::CSR : public BaseHeteroGraph {
: BaseHeteroGraph(metagraph), adj_(csr) { : BaseHeteroGraph(metagraph), adj_(csr) {
} }
CSR() {
// set magic num_rows/num_cols to mark it as undefined
// adj_.num_rows == 0 and adj_.num_cols == 0 means empty UnitGraph which is supported
adj_.num_rows = -1;
adj_.num_cols = -1;
};
bool defined() const {
return (adj_.num_rows >= 0) || (adj_.num_cols >= 0);
}
inline dgl_type_t SrcType() const { inline dgl_type_t SrcType() const {
return 0; return 0;
} }
...@@ -767,12 +787,9 @@ class UnitGraph::CSR : public BaseHeteroGraph { ...@@ -767,12 +787,9 @@ class UnitGraph::CSR : public BaseHeteroGraph {
fs->Write(adj_); fs->Write(adj_);
} }
private: private:
friend class Serializer; friend class Serializer;
CSR() {};
/*! \brief internal adjacency matrix. Data array stores edge ids */ /*! \brief internal adjacency matrix. Data array stores edge ids */
aten::CSRMatrix adj_; aten::CSRMatrix adj_;
}; };
...@@ -1171,9 +1188,12 @@ HeteroGraphPtr UnitGraph::AsNumBits(HeteroGraphPtr g, uint8_t bits) { ...@@ -1171,9 +1188,12 @@ HeteroGraphPtr UnitGraph::AsNumBits(HeteroGraphPtr g, uint8_t bits) {
} else { } else {
auto bg = std::dynamic_pointer_cast<UnitGraph>(g); auto bg = std::dynamic_pointer_cast<UnitGraph>(g);
CHECK_NOTNULL(bg); CHECK_NOTNULL(bg);
CSRPtr new_incsr = (bg->in_csr_)? CSRPtr(new CSR(bg->in_csr_->AsNumBits(bits))) : nullptr; CSRPtr new_incsr =
CSRPtr new_outcsr = (bg->out_csr_)? CSRPtr(new CSR(bg->out_csr_->AsNumBits(bits))) : nullptr; (bg->in_csr_->defined())? CSRPtr(new CSR(bg->in_csr_->AsNumBits(bits))) : nullptr;
COOPtr new_coo = (bg->coo_)? COOPtr(new COO(bg->coo_->AsNumBits(bits))) : nullptr; CSRPtr new_outcsr =
(bg->out_csr_->defined())? CSRPtr(new CSR(bg->out_csr_->AsNumBits(bits))) : nullptr;
COOPtr new_coo =
(bg->coo_->defined())? COOPtr(new COO(bg->coo_->AsNumBits(bits))) : nullptr;
return HeteroGraphPtr( return HeteroGraphPtr(
new UnitGraph(g->meta_graph(), new_incsr, new_outcsr, new_coo, bg->restrict_format_)); new UnitGraph(g->meta_graph(), new_incsr, new_outcsr, new_coo, bg->restrict_format_));
} }
...@@ -1185,9 +1205,12 @@ HeteroGraphPtr UnitGraph::CopyTo(HeteroGraphPtr g, const DLContext& ctx) { ...@@ -1185,9 +1205,12 @@ HeteroGraphPtr UnitGraph::CopyTo(HeteroGraphPtr g, const DLContext& ctx) {
} else { } else {
auto bg = std::dynamic_pointer_cast<UnitGraph>(g); auto bg = std::dynamic_pointer_cast<UnitGraph>(g);
CHECK_NOTNULL(bg); CHECK_NOTNULL(bg);
CSRPtr new_incsr = (bg->in_csr_)? CSRPtr(new CSR(bg->in_csr_->CopyTo(ctx))) : nullptr; CSRPtr new_incsr =
CSRPtr new_outcsr = (bg->out_csr_)? CSRPtr(new CSR(bg->out_csr_->CopyTo(ctx))) : nullptr; (bg->in_csr_->defined())? CSRPtr(new CSR(bg->in_csr_->CopyTo(ctx))) : nullptr;
COOPtr new_coo = (bg->coo_)? COOPtr(new COO(bg->coo_->CopyTo(ctx))) : nullptr; CSRPtr new_outcsr =
(bg->out_csr_->defined())? CSRPtr(new CSR(bg->out_csr_->CopyTo(ctx))) : nullptr;
COOPtr new_coo =
(bg->coo_->defined())? COOPtr(new COO(bg->coo_->CopyTo(ctx))) : nullptr;
return HeteroGraphPtr( return HeteroGraphPtr(
new UnitGraph(g->meta_graph(), new_incsr, new_outcsr, new_coo, bg->restrict_format_)); new UnitGraph(g->meta_graph(), new_incsr, new_outcsr, new_coo, bg->restrict_format_));
} }
...@@ -1196,22 +1219,35 @@ HeteroGraphPtr UnitGraph::CopyTo(HeteroGraphPtr g, const DLContext& ctx) { ...@@ -1196,22 +1219,35 @@ HeteroGraphPtr UnitGraph::CopyTo(HeteroGraphPtr g, const DLContext& ctx) {
UnitGraph::UnitGraph(GraphPtr metagraph, CSRPtr in_csr, CSRPtr out_csr, COOPtr coo, UnitGraph::UnitGraph(GraphPtr metagraph, CSRPtr in_csr, CSRPtr out_csr, COOPtr coo,
SparseFormat restrict_format) SparseFormat restrict_format)
: BaseHeteroGraph(metagraph), in_csr_(in_csr), out_csr_(out_csr), coo_(coo) { : BaseHeteroGraph(metagraph), in_csr_(in_csr), out_csr_(out_csr), coo_(coo) {
restrict_format_ = AutoDetectFormat(in_csr, out_csr, coo, restrict_format); if (!in_csr_) {
in_csr_ = CSRPtr(new CSR());
}
if (!out_csr_) {
out_csr_ = CSRPtr(new CSR());
}
if (!coo_) {
coo_ = COOPtr(new COO());
}
restrict_format_ = AutoDetectFormat(in_csr_, out_csr_, coo_, restrict_format);
switch (restrict_format) { switch (restrict_format) {
case SparseFormat::kCSC: case SparseFormat::kCSC:
in_csr_ = GetInCSR(); in_csr_ = GetInCSR();
coo_ = nullptr; // cleaning other format
out_csr_ = nullptr; out_csr_ = out_csr_->defined() ? CSRPtr(new CSR()) : out_csr_;
coo_ = coo_->defined() ? COOPtr(new COO()) : coo_;
break; break;
case SparseFormat::kCSR: case SparseFormat::kCSR:
out_csr_ = GetOutCSR(); out_csr_ = GetOutCSR();
coo_ = nullptr; // cleaning other format
in_csr_ = nullptr; in_csr_ = in_csr_->defined() ? CSRPtr(new CSR()) : in_csr_;
coo_ = coo_->defined() ? COOPtr(new COO()) : coo_;
break; break;
case SparseFormat::kCOO: case SparseFormat::kCOO:
coo_ = GetCOO(); coo_ = GetCOO();
in_csr_ = nullptr; // cleaning other format
out_csr_ = nullptr; in_csr_ = in_csr_->defined() ? CSRPtr(new CSR()) : in_csr_;
out_csr_ = out_csr_->defined() ? CSRPtr(new CSR()) : out_csr_;
break; break;
default: default:
break; break;
...@@ -1236,10 +1272,16 @@ HeteroGraphPtr UnitGraph::CreateHomographFrom( ...@@ -1236,10 +1272,16 @@ HeteroGraphPtr UnitGraph::CreateHomographFrom(
if (has_in_csr) if (has_in_csr)
in_csr_ptr = CSRPtr(new CSR(mg, in_csr)); in_csr_ptr = CSRPtr(new CSR(mg, in_csr));
else
in_csr_ptr = CSRPtr(new CSR());
if (has_out_csr) if (has_out_csr)
out_csr_ptr = CSRPtr(new CSR(mg, out_csr)); out_csr_ptr = CSRPtr(new CSR(mg, out_csr));
else
out_csr_ptr = CSRPtr(new CSR());
if (has_coo) if (has_coo)
coo_ptr = COOPtr(new COO(mg, coo)); coo_ptr = COOPtr(new COO(mg, coo));
else
coo_ptr = COOPtr(new COO());
return HeteroGraphPtr(new UnitGraph(mg, in_csr_ptr, out_csr_ptr, coo_ptr, restrict_format)); return HeteroGraphPtr(new UnitGraph(mg, in_csr_ptr, out_csr_ptr, coo_ptr, restrict_format));
} }
...@@ -1251,19 +1293,23 @@ UnitGraph::CSRPtr UnitGraph::GetInCSR(bool inplace) const { ...@@ -1251,19 +1293,23 @@ UnitGraph::CSRPtr UnitGraph::GetInCSR(bool inplace) const {
LOG(FATAL) << "The graph have restricted sparse format " << GetRestrictFormat() << LOG(FATAL) << "The graph have restricted sparse format " << GetRestrictFormat() <<
", cannot create CSC matrix."; ", cannot create CSC matrix.";
CSRPtr ret = in_csr_; CSRPtr ret = in_csr_;
if (!in_csr_) { if (!in_csr_->defined()) {
if (out_csr_) { if (out_csr_->defined()) {
const auto& newadj = aten::CSRTranspose(out_csr_->adj()); const auto& newadj = aten::CSRTranspose(out_csr_->adj());
ret = std::make_shared<CSR>(meta_graph(), newadj);
if (inplace) if (inplace)
const_cast<UnitGraph*>(this)->in_csr_ = ret; *(const_cast<UnitGraph*>(this)->in_csr_) = CSR(meta_graph(), newadj);
else
ret = std::make_shared<CSR>(meta_graph(), newadj);
} else { } else {
CHECK(coo_) << "None of CSR, COO exist"; CHECK(coo_->defined()) << "None of CSR, COO exist";
const auto& newadj = aten::CSRSort(aten::COOToCSR( const auto& newadj = aten::CSRSort(aten::COOToCSR(
aten::COOTranspose(coo_->adj()))); aten::COOTranspose(coo_->adj())));
ret = std::make_shared<CSR>(meta_graph(), newadj);
if (inplace) if (inplace)
const_cast<UnitGraph*>(this)->in_csr_ = ret; *(const_cast<UnitGraph*>(this)->in_csr_) = CSR(meta_graph(), newadj);
else
ret = std::make_shared<CSR>(meta_graph(), newadj);
} }
} }
return ret; return ret;
...@@ -1277,18 +1323,22 @@ UnitGraph::CSRPtr UnitGraph::GetOutCSR(bool inplace) const { ...@@ -1277,18 +1323,22 @@ UnitGraph::CSRPtr UnitGraph::GetOutCSR(bool inplace) const {
LOG(FATAL) << "The graph have restricted sparse format " << GetRestrictFormat() << LOG(FATAL) << "The graph have restricted sparse format " << GetRestrictFormat() <<
", cannot create CSR matrix."; ", cannot create CSR matrix.";
CSRPtr ret = out_csr_; CSRPtr ret = out_csr_;
if (!out_csr_) { if (!out_csr_->defined()) {
if (in_csr_) { if (in_csr_->defined()) {
const auto& newadj = aten::CSRSort(aten::CSRTranspose(in_csr_->adj())); const auto& newadj = aten::CSRSort(aten::CSRTranspose(in_csr_->adj()));
ret = std::make_shared<CSR>(meta_graph(), newadj);
if (inplace) if (inplace)
const_cast<UnitGraph*>(this)->out_csr_ = ret; *(const_cast<UnitGraph*>(this)->out_csr_) = CSR(meta_graph(), newadj);
else
ret = std::make_shared<CSR>(meta_graph(), newadj);
} else { } else {
CHECK(coo_) << "None of CSR, COO exist"; CHECK(coo_->defined()) << "None of CSR, COO exist";
const auto& newadj = aten::CSRSort(aten::COOToCSR(coo_->adj())); const auto& newadj = aten::CSRSort(aten::COOToCSR(coo_->adj()));
ret = std::make_shared<CSR>(meta_graph(), newadj);
if (inplace) if (inplace)
const_cast<UnitGraph*>(this)->out_csr_ = ret; *(const_cast<UnitGraph*>(this)->out_csr_) = CSR(meta_graph(), newadj);
else
ret = std::make_shared<CSR>(meta_graph(), newadj);
} }
} }
return ret; return ret;
...@@ -1302,18 +1352,22 @@ UnitGraph::COOPtr UnitGraph::GetCOO(bool inplace) const { ...@@ -1302,18 +1352,22 @@ UnitGraph::COOPtr UnitGraph::GetCOO(bool inplace) const {
LOG(FATAL) << "The graph have restricted sparse format " << GetRestrictFormat() << LOG(FATAL) << "The graph have restricted sparse format " << GetRestrictFormat() <<
", cannot create COO matrix."; ", cannot create COO matrix.";
COOPtr ret = coo_; COOPtr ret = coo_;
if (!coo_) { if (!coo_->defined()) {
if (in_csr_) { if (in_csr_->defined()) {
const auto& newadj = aten::COOTranspose(aten::CSRToCOO(in_csr_->adj(), true)); const auto& newadj = aten::COOTranspose(aten::CSRToCOO(in_csr_->adj(), true));
ret = std::make_shared<COO>(meta_graph(), newadj);
if (inplace) if (inplace)
const_cast<UnitGraph*>(this)->coo_ = ret; *(const_cast<UnitGraph*>(this)->coo_) = COO(meta_graph(), newadj);
else
ret = std::make_shared<COO>(meta_graph(), newadj);
} else { } else {
CHECK(out_csr_) << "Both CSR are missing."; CHECK(out_csr_->defined()) << "Both CSR are missing.";
const auto& newadj = aten::CSRToCOO(out_csr_->adj(), true); const auto& newadj = aten::CSRToCOO(out_csr_->adj(), true);
ret = std::make_shared<COO>(meta_graph(), newadj);
if (inplace) if (inplace)
const_cast<UnitGraph*>(this)->coo_ = ret; *(const_cast<UnitGraph*>(this)->coo_) = COO(meta_graph(), newadj);
else
ret = std::make_shared<COO>(meta_graph(), newadj);
} }
} }
return ret; return ret;
...@@ -1332,9 +1386,9 @@ aten::COOMatrix UnitGraph::GetCOOMatrix(dgl_type_t etype) const { ...@@ -1332,9 +1386,9 @@ aten::COOMatrix UnitGraph::GetCOOMatrix(dgl_type_t etype) const {
} }
HeteroGraphPtr UnitGraph::GetAny() const { HeteroGraphPtr UnitGraph::GetAny() const {
if (in_csr_) { if (in_csr_->defined()) {
return in_csr_; return in_csr_;
} else if (out_csr_) { } else if (out_csr_->defined()) {
return out_csr_; return out_csr_;
} else { } else {
return coo_; return coo_;
...@@ -1343,11 +1397,11 @@ HeteroGraphPtr UnitGraph::GetAny() const { ...@@ -1343,11 +1397,11 @@ HeteroGraphPtr UnitGraph::GetAny() const {
dgl_format_code_t UnitGraph::GetFormatInUse() const { dgl_format_code_t UnitGraph::GetFormatInUse() const {
dgl_format_code_t ret = 0; dgl_format_code_t ret = 0;
if (in_csr_) ret = ret | 1; if (in_csr_->defined()) ret = ret | 1;
ret = ret << 1; ret = ret << 1;
if (out_csr_) ret = ret | 1; if (out_csr_->defined()) ret = ret | 1;
ret = ret << 1; ret = ret << 1;
if (coo_) ret = ret | 1; if (coo_->defined()) ret = ret | 1;
return ret; return ret;
} }
...@@ -1381,7 +1435,12 @@ HeteroGraphPtr UnitGraph::GetGraphInFormat(SparseFormat restrict_format) const { ...@@ -1381,7 +1435,12 @@ HeteroGraphPtr UnitGraph::GetGraphInFormat(SparseFormat restrict_format) const {
num_vtypes, GetOutCSR(false)->adj(), restrict_format); num_vtypes, GetOutCSR(false)->adj(), restrict_format);
case SparseFormat::kAny: case SparseFormat::kAny:
return HeteroGraphPtr( return HeteroGraphPtr(
new UnitGraph(meta_graph_, in_csr_, out_csr_, coo_, restrict_format)); // TODO(xiangsx) Make it as graph storage.Clone()
new UnitGraph(meta_graph_,
(in_csr_->defined()) ? CSRPtr(new CSR(*in_csr_)) : nullptr,
(out_csr_->defined()) ? CSRPtr(new CSR(*out_csr_)) : nullptr,
(coo_->defined()) ? COOPtr(new COO(*coo_)) : nullptr,
restrict_format));
default: // SparseFormat::kAuto default: // SparseFormat::kAuto
LOG(FATAL) << "Must specify a restrict format."; LOG(FATAL) << "Must specify a restrict format.";
return nullptr; return nullptr;
...@@ -1392,7 +1451,7 @@ SparseFormat UnitGraph::AutoDetectFormat( ...@@ -1392,7 +1451,7 @@ SparseFormat UnitGraph::AutoDetectFormat(
CSRPtr in_csr, CSRPtr out_csr, COOPtr coo, SparseFormat restrict_format) const { CSRPtr in_csr, CSRPtr out_csr, COOPtr coo, SparseFormat restrict_format) const {
if (restrict_format != SparseFormat::kAuto) if (restrict_format != SparseFormat::kAuto)
return restrict_format; return restrict_format;
if (coo && coo->IsHypersparse()) if (coo && coo->defined() && coo->IsHypersparse())
return SparseFormat::kCOO; return SparseFormat::kCOO;
return SparseFormat::kAny; return SparseFormat::kAny;
} }
...@@ -1402,9 +1461,9 @@ SparseFormat UnitGraph::SelectFormat(SparseFormat preferred_format) const { ...@@ -1402,9 +1461,9 @@ SparseFormat UnitGraph::SelectFormat(SparseFormat preferred_format) const {
return restrict_format_; // force to select the restricted format return restrict_format_; // force to select the restricted format
else if (preferred_format != SparseFormat::kAny) else if (preferred_format != SparseFormat::kAny)
return preferred_format; return preferred_format;
else if (in_csr_) else if (in_csr_->defined())
return SparseFormat::kCSC; return SparseFormat::kCSC;
else if (out_csr_) else if (out_csr_->defined())
return SparseFormat::kCSR; return SparseFormat::kCSR;
else else
return SparseFormat::kCOO; return SparseFormat::kCOO;
...@@ -1414,15 +1473,15 @@ GraphPtr UnitGraph::AsImmutableGraph() const { ...@@ -1414,15 +1473,15 @@ GraphPtr UnitGraph::AsImmutableGraph() const {
CHECK(NumVertexTypes() == 1) << "not a homogeneous graph"; CHECK(NumVertexTypes() == 1) << "not a homogeneous graph";
dgl::CSRPtr in_csr_ptr = nullptr, out_csr_ptr = nullptr; dgl::CSRPtr in_csr_ptr = nullptr, out_csr_ptr = nullptr;
dgl::COOPtr coo_ptr = nullptr; dgl::COOPtr coo_ptr = nullptr;
if (in_csr_) { if (in_csr_->defined()) {
aten::CSRMatrix csc = GetCSCMatrix(0); aten::CSRMatrix csc = GetCSCMatrix(0);
in_csr_ptr = dgl::CSRPtr(new dgl::CSR(csc.indptr, csc.indices, csc.data)); in_csr_ptr = dgl::CSRPtr(new dgl::CSR(csc.indptr, csc.indices, csc.data));
} }
if (out_csr_) { if (out_csr_->defined()) {
aten::CSRMatrix csr = GetCSRMatrix(0); aten::CSRMatrix csr = GetCSRMatrix(0);
out_csr_ptr = dgl::CSRPtr(new dgl::CSR(csr.indptr, csr.indices, csr.data)); out_csr_ptr = dgl::CSRPtr(new dgl::CSR(csr.indptr, csr.indices, csr.data));
} }
if (coo_) { if (coo_->defined()) {
aten::COOMatrix coo = GetCOOMatrix(0); aten::COOMatrix coo = GetCOOMatrix(0);
if (!COOHasData(coo)) { if (!COOHasData(coo)) {
coo_ptr = dgl::COOPtr(new dgl::COO(NumVertices(0), coo.row, coo.col)); coo_ptr = dgl::COOPtr(new dgl::COO(NumVertices(0), coo.row, coo.col));
...@@ -1463,6 +1522,16 @@ bool UnitGraph::Load(dmlc::Stream* fs) { ...@@ -1463,6 +1522,16 @@ bool UnitGraph::Load(dmlc::Stream* fs) {
break; break;
} }
if (!in_csr_) {
in_csr_ = CSRPtr(new CSR());
}
if (!out_csr_) {
out_csr_ = CSRPtr(new CSR());
}
if (!coo_) {
coo_ = COOPtr(new COO());
}
meta_graph_ = GetAny()->meta_graph(); meta_graph_ = GetAny()->meta_graph();
return true; return true;
...@@ -1492,4 +1561,14 @@ void UnitGraph::Save(dmlc::Stream* fs) const { ...@@ -1492,4 +1561,14 @@ void UnitGraph::Save(dmlc::Stream* fs) const {
} }
} }
UnitGraphPtr UnitGraph::Reverse() const {
CSRPtr new_incsr = out_csr_, new_outcsr = in_csr_;
COOPtr new_coo = nullptr;
if (coo_->defined()) {
new_coo = COOPtr(new COO(coo_->meta_graph(), aten::COOTranspose(coo_->adj())));
}
return UnitGraphPtr(new UnitGraph(meta_graph(), new_incsr, new_outcsr, new_coo));
}
} // namespace dgl } // namespace dgl
...@@ -266,6 +266,9 @@ class UnitGraph : public BaseHeteroGraph { ...@@ -266,6 +266,9 @@ class UnitGraph : public BaseHeteroGraph {
/*! \return Save UnitGraph to stream, using CSRMatrix */ /*! \return Save UnitGraph to stream, using CSRMatrix */
void Save(dmlc::Stream* fs) const; void Save(dmlc::Stream* fs) const;
/*! \return the reversed graph */
UnitGraphPtr Reverse() const;
private: private:
friend class Serializer; friend class Serializer;
friend class HeteroGraph; friend class HeteroGraph;
......
...@@ -843,4 +843,4 @@ if __name__ == '__main__': ...@@ -843,4 +843,4 @@ if __name__ == '__main__':
# test_group_apply_edges() # test_group_apply_edges()
# test_local_var() # test_local_var()
# test_local_scope() # test_local_scope()
# test_issue_1088() test_issue_1088('int64')
...@@ -1929,6 +1929,134 @@ def test_edges_order(): ...@@ -1929,6 +1929,134 @@ def test_edges_order():
assert F.array_equal(F.copy_to(dst, F.cpu()), assert F.array_equal(F.copy_to(dst, F.cpu()),
F.copy_to(F.tensor([1, 1, 2, 2, 1]), F.cpu())) F.copy_to(F.tensor([1, 1, 2, 2, 1]), F.cpu()))
@parametrize_dtype
def test_reverse(index_dtype):
g = dgl.heterograph({
('user', 'follows', 'user'): ([0, 1, 2, 4, 3 ,1, 3], [1, 2, 3, 2, 0, 0, 1]),
}, index_dtype=index_dtype)
gidx = g._graph
r_gidx = gidx.reverse()
assert gidx.number_of_nodes(0) == r_gidx.number_of_nodes(0)
assert gidx.number_of_edges(0) == r_gidx.number_of_edges(0)
g_s, g_d, _ = gidx.edges(0)
rg_s, rg_d, _ = r_gidx.edges(0)
assert F.array_equal(g_s.tousertensor(), rg_d.tousertensor())
assert F.array_equal(g_d.tousertensor(), rg_s.tousertensor())
# force to start with 'csr'
gidx = gidx.to_format('csr')
gidx = gidx.to_format('any')
r_gidx = gidx.reverse()
assert gidx.format_in_use(0)[0] == 'csr'
assert r_gidx.format_in_use(0)[0] == 'csc'
assert gidx.number_of_nodes(0) == r_gidx.number_of_nodes(0)
assert gidx.number_of_edges(0) == r_gidx.number_of_edges(0)
g_s, g_d, _ = gidx.edges(0)
rg_s, rg_d, _ = r_gidx.edges(0)
assert F.array_equal(g_s.tousertensor(), rg_d.tousertensor())
assert F.array_equal(g_d.tousertensor(), rg_s.tousertensor())
# force to start with 'csc'
gidx = gidx.to_format('csc')
gidx = gidx.to_format('any')
r_gidx = gidx.reverse()
assert gidx.format_in_use(0)[0] == 'csc'
assert r_gidx.format_in_use(0)[0] == 'csr'
assert gidx.number_of_nodes(0) == r_gidx.number_of_nodes(0)
assert gidx.number_of_edges(0) == r_gidx.number_of_edges(0)
g_s, g_d, _ = gidx.edges(0)
rg_s, rg_d, _ = r_gidx.edges(0)
assert F.array_equal(g_s.tousertensor(), rg_d.tousertensor())
assert F.array_equal(g_d.tousertensor(), rg_s.tousertensor())
g = dgl.heterograph({
('user', 'follows', 'user'): ([0, 1, 2, 4, 3 ,1, 3], [1, 2, 3, 2, 0, 0, 1]),
('user', 'plays', 'game'): ([0, 0, 2, 3, 3, 4, 1], [1, 0, 1, 0, 1, 0, 0]),
('developer', 'develops', 'game'): ([0, 1, 1, 2], [0, 0, 1, 1]),
}, index_dtype=index_dtype)
gidx = g._graph
r_gidx = gidx.reverse()
# three node types and three edge types
assert gidx.number_of_nodes(0) == r_gidx.number_of_nodes(0)
assert gidx.number_of_nodes(1) == r_gidx.number_of_nodes(1)
assert gidx.number_of_nodes(2) == r_gidx.number_of_nodes(2)
assert gidx.number_of_edges(0) == r_gidx.number_of_edges(0)
assert gidx.number_of_edges(1) == r_gidx.number_of_edges(1)
assert gidx.number_of_edges(2) == r_gidx.number_of_edges(2)
g_s, g_d, _ = gidx.edges(0)
rg_s, rg_d, _ = r_gidx.edges(0)
assert F.array_equal(g_s.tousertensor(), rg_d.tousertensor())
assert F.array_equal(g_d.tousertensor(), rg_s.tousertensor())
g_s, g_d, _ = gidx.edges(1)
rg_s, rg_d, _ = r_gidx.edges(1)
assert F.array_equal(g_s.tousertensor(), rg_d.tousertensor())
assert F.array_equal(g_d.tousertensor(), rg_s.tousertensor())
g_s, g_d, _ = gidx.edges(2)
rg_s, rg_d, _ = r_gidx.edges(2)
assert F.array_equal(g_s.tousertensor(), rg_d.tousertensor())
assert F.array_equal(g_d.tousertensor(), rg_s.tousertensor())
# force to start with 'csr'
gidx = gidx.to_format('csr')
gidx = gidx.to_format('any')
r_gidx = gidx.reverse()
# three node types and three edge types
assert gidx.format_in_use(0)[0] == 'csr'
assert r_gidx.format_in_use(0)[0] == 'csc'
assert gidx.format_in_use(1)[0] == 'csr'
assert r_gidx.format_in_use(1)[0] == 'csc'
assert gidx.format_in_use(2)[0] == 'csr'
assert r_gidx.format_in_use(2)[0] == 'csc'
assert gidx.number_of_nodes(0) == r_gidx.number_of_nodes(0)
assert gidx.number_of_nodes(1) == r_gidx.number_of_nodes(1)
assert gidx.number_of_nodes(2) == r_gidx.number_of_nodes(2)
assert gidx.number_of_edges(0) == r_gidx.number_of_edges(0)
assert gidx.number_of_edges(1) == r_gidx.number_of_edges(1)
assert gidx.number_of_edges(2) == r_gidx.number_of_edges(2)
g_s, g_d, _ = gidx.edges(0)
rg_s, rg_d, _ = r_gidx.edges(0)
assert F.array_equal(g_s.tousertensor(), rg_d.tousertensor())
assert F.array_equal(g_d.tousertensor(), rg_s.tousertensor())
g_s, g_d, _ = gidx.edges(1)
rg_s, rg_d, _ = r_gidx.edges(1)
assert F.array_equal(g_s.tousertensor(), rg_d.tousertensor())
assert F.array_equal(g_d.tousertensor(), rg_s.tousertensor())
g_s, g_d, _ = gidx.edges(2)
rg_s, rg_d, _ = r_gidx.edges(2)
assert F.array_equal(g_s.tousertensor(), rg_d.tousertensor())
assert F.array_equal(g_d.tousertensor(), rg_s.tousertensor())
# force to start with 'csc'
gidx = gidx.to_format('csc')
gidx = gidx.to_format('any')
r_gidx = gidx.reverse()
# three node types and three edge types
assert gidx.format_in_use(0)[0] == 'csc'
assert r_gidx.format_in_use(0)[0] == 'csr'
assert gidx.format_in_use(1)[0] == 'csc'
assert r_gidx.format_in_use(1)[0] == 'csr'
assert gidx.format_in_use(2)[0] == 'csc'
assert r_gidx.format_in_use(2)[0] == 'csr'
assert gidx.number_of_nodes(0) == r_gidx.number_of_nodes(0)
assert gidx.number_of_nodes(1) == r_gidx.number_of_nodes(1)
assert gidx.number_of_nodes(2) == r_gidx.number_of_nodes(2)
assert gidx.number_of_edges(0) == r_gidx.number_of_edges(0)
assert gidx.number_of_edges(1) == r_gidx.number_of_edges(1)
assert gidx.number_of_edges(2) == r_gidx.number_of_edges(2)
g_s, g_d, _ = gidx.edges(0)
rg_s, rg_d, _ = r_gidx.edges(0)
assert F.array_equal(g_s.tousertensor(), rg_d.tousertensor())
assert F.array_equal(g_d.tousertensor(), rg_s.tousertensor())
g_s, g_d, _ = gidx.edges(1)
rg_s, rg_d, _ = r_gidx.edges(1)
assert F.array_equal(g_s.tousertensor(), rg_d.tousertensor())
assert F.array_equal(g_d.tousertensor(), rg_s.tousertensor())
g_s, g_d, _ = gidx.edges(2)
rg_s, rg_d, _ = r_gidx.edges(2)
assert F.array_equal(g_s.tousertensor(), rg_d.tousertensor())
assert F.array_equal(g_d.tousertensor(), rg_s.tousertensor())
if __name__ == '__main__': if __name__ == '__main__':
# test_create() # test_create()
# test_query() # test_query()
...@@ -1942,17 +2070,20 @@ if __name__ == '__main__': ...@@ -1942,17 +2070,20 @@ if __name__ == '__main__':
# test_convert() # test_convert()
# test_to_device() # test_to_device()
# test_transform("int32") # test_transform("int32")
test_subgraph("int32") # test_subgraph("int32")
test_subgraph_mask("int32") # test_subgraph_mask("int32")
# test_apply() # test_apply()
# test_level1() # test_level1()
# test_level2() # test_level2()
# test_updates() # test_updates()
# test_backward() # test_backward()
# test_empty_heterograph() # test_empty_heterograph('int32')
# test_types_in_function() # test_types_in_function()
# test_stack_reduce() # test_stack_reduce()
# test_isolated_ntype() # test_isolated_ntype()
# test_bipartite() # test_bipartite()
# test_dtype_cast() # test_dtype_cast()
# test_reverse("int32")
test_format()
pass pass
/*!
* Copyright (c) 2019 by Contributors
* \file test_unit_graph.cc
* \brief Test UnitGraph
*/
#include <gtest/gtest.h>
#include <dgl/array.h>
#include <vector>
#include <dgl/immutable_graph.h>
#include "./common.h"
#include "./../src/graph/heterograph.h"
#include "../../src/graph/unit_graph.h"
using namespace dgl;
using namespace dgl::aten;
using namespace dgl::runtime;
template <typename IdType>
aten::CSRMatrix CSR1(DLContext ctx) {
/*
* G = [[0, 0, 1],
* [1, 0, 1],
* [0, 1, 0],
* [1, 0, 1]]
*/
IdArray g_indptr =
aten::VecToIdArray(std::vector<IdType>({0, 1, 3, 4, 6}), sizeof(IdType)*8, CTX);
IdArray g_indices =
aten::VecToIdArray(std::vector<IdType>({2, 0, 2, 1, 0, 2}), sizeof(IdType)*8, CTX);
const aten::CSRMatrix &csr_a = aten::CSRMatrix(
4,
3,
g_indptr,
g_indices,
aten::NullArray(),
false);
return csr_a;
}
template aten::CSRMatrix CSR1<int32_t>(DLContext ctx);
template aten::CSRMatrix CSR1<int64_t>(DLContext ctx);
template <typename IdType>
aten::COOMatrix COO1(DLContext ctx) {
/*
* G = [[1, 1, 0],
* [0, 1, 0]]
*/
IdArray g_row =
aten::VecToIdArray(std::vector<IdType>({0, 0, 1}), sizeof(IdType)*8, CTX);
IdArray g_col =
aten::VecToIdArray(std::vector<IdType>({0, 1, 1}), sizeof(IdType)*8, CTX);
const aten::COOMatrix &coo = aten::COOMatrix(
2,
3,
g_row,
g_col,
aten::NullArray(),
true,
true);
return coo;
}
template aten::COOMatrix COO1<int32_t>(DLContext ctx);
template aten::COOMatrix COO1<int64_t>(DLContext ctx);
template <typename IdType>
void _TestUnitGraph(DLContext ctx) {
const aten::CSRMatrix &csr = CSR1<IdType>(ctx);
const aten::COOMatrix &coo = COO1<IdType>(ctx);
auto hg = std::dynamic_pointer_cast<HeteroGraph>(CreateFromCSC(2, csr, SparseFormat::kAny));
UnitGraphPtr g = hg->relation_graphs()[0];
ASSERT_EQ(g->GetFormatInUse(), 4);
hg = std::dynamic_pointer_cast<HeteroGraph>(CreateFromCSR(2, csr, SparseFormat::kAny));
g = hg->relation_graphs()[0];
ASSERT_EQ(g->GetFormatInUse(), 2);
hg = std::dynamic_pointer_cast<HeteroGraph>(CreateFromCOO(2, coo, SparseFormat::kAny));
g = hg->relation_graphs()[0];
ASSERT_EQ(g->GetFormatInUse(), 1);
auto src = VecToIdArray<int64_t>({1, 2, 5, 3});
auto dst = VecToIdArray<int64_t>({1, 6, 2, 6});
auto mg = std::dynamic_pointer_cast<UnitGraph>(
dgl::UnitGraph::CreateFromCOO(2, 9, 8, src, dst, dgl::SparseFormat::kCOO));
ASSERT_EQ(mg->GetFormatInUse(), 1);
auto hmg = dgl::UnitGraph::CreateFromCOO(1, 8, 8, src, dst, dgl::SparseFormat::kCOO);
auto img = std::dynamic_pointer_cast<ImmutableGraph>(hmg->AsImmutableGraph());
ASSERT_TRUE(img != nullptr);
mg = std::dynamic_pointer_cast<UnitGraph>(
dgl::UnitGraph::CreateFromCOO(2, 9, 8, src, dst, dgl::SparseFormat::kCSR));
ASSERT_EQ(mg->GetFormatInUse(), 2);
hmg = dgl::UnitGraph::CreateFromCOO(1, 8, 8, src, dst, dgl::SparseFormat::kCSR);
img = std::dynamic_pointer_cast<ImmutableGraph>(hmg->AsImmutableGraph());
ASSERT_TRUE(img != nullptr);
mg = std::dynamic_pointer_cast<UnitGraph>(
dgl::UnitGraph::CreateFromCOO(2, 9, 8, src, dst, dgl::SparseFormat::kCSC));
ASSERT_EQ(mg->GetFormatInUse(), 4);
hmg = dgl::UnitGraph::CreateFromCOO(1, 8, 8, src, dst, dgl::SparseFormat::kCSC);
img = std::dynamic_pointer_cast<ImmutableGraph>(hmg->AsImmutableGraph());
ASSERT_TRUE(img != nullptr);
hg = std::dynamic_pointer_cast<HeteroGraph>(CreateFromCSC(2, csr, SparseFormat::kAuto));
g = hg->relation_graphs()[0];
ASSERT_EQ(g->GetFormatInUse(), 4);
hg = std::dynamic_pointer_cast<HeteroGraph>(CreateFromCSR(2, csr, SparseFormat::kAuto));
g = hg->relation_graphs()[0];
ASSERT_EQ(g->GetFormatInUse(), 2);
hg = std::dynamic_pointer_cast<HeteroGraph>(CreateFromCOO(2, coo, SparseFormat::kAuto));
g = hg->relation_graphs()[0];
ASSERT_EQ(g->GetFormatInUse(), 1);
}
template <typename IdType>
void _TestUnitGraph_GetInCSR(DLContext ctx) {
const aten::CSRMatrix &csr = CSR1<IdType>(ctx);
const aten::COOMatrix &coo = COO1<IdType>(ctx);
auto hg = std::dynamic_pointer_cast<HeteroGraph>(CreateFromCSC(2, csr, SparseFormat::kAny));
UnitGraphPtr g = hg->relation_graphs()[0];
auto in_csr_matrix = g->GetCSCMatrix(0);
ASSERT_EQ(in_csr_matrix.num_rows, csr.num_rows);
ASSERT_EQ(in_csr_matrix.num_cols, csr.num_cols);
ASSERT_EQ(g->GetFormatInUse(), 4);
// test out csr
hg = std::dynamic_pointer_cast<HeteroGraph>(CreateFromCSR(2, csr, SparseFormat::kAny));
g = hg->relation_graphs()[0];
UnitGraphPtr g_ptr = std::dynamic_pointer_cast<UnitGraph>(g->GetGraphInFormat(SparseFormat::kCSC));
in_csr_matrix = g_ptr->GetCSCMatrix(0);
ASSERT_EQ(in_csr_matrix.num_cols, csr.num_rows);
ASSERT_EQ(in_csr_matrix.num_rows, csr.num_cols);
ASSERT_EQ(g->GetFormatInUse(), 2);
in_csr_matrix = g->GetCSCMatrix(0);
ASSERT_EQ(in_csr_matrix.num_cols, csr.num_rows);
ASSERT_EQ(in_csr_matrix.num_rows, csr.num_cols);
ASSERT_EQ(g->GetFormatInUse(), 6);
// test out coo
hg = std::dynamic_pointer_cast<HeteroGraph>(CreateFromCOO(2, coo, SparseFormat::kAny));
g = hg->relation_graphs()[0];
g_ptr = std::dynamic_pointer_cast<UnitGraph>(g->GetGraphInFormat(SparseFormat::kCSC));
in_csr_matrix = g_ptr->GetCSCMatrix(0);
ASSERT_EQ(in_csr_matrix.num_cols, coo.num_rows);
ASSERT_EQ(in_csr_matrix.num_rows, coo.num_cols);
ASSERT_EQ(g->GetFormatInUse(), 1);
in_csr_matrix = g->GetCSCMatrix(0);
ASSERT_EQ(in_csr_matrix.num_cols, coo.num_rows);
ASSERT_EQ(in_csr_matrix.num_rows, coo.num_cols);
ASSERT_EQ(g->GetFormatInUse(), 5);
}
template <typename IdType>
void _TestUnitGraph_GetOutCSR(DLContext ctx) {
const aten::CSRMatrix &csr = CSR1<IdType>(ctx);
const aten::COOMatrix &coo = COO1<IdType>(ctx);
auto hg = std::dynamic_pointer_cast<HeteroGraph>(CreateFromCSC(2, csr, SparseFormat::kAny));
UnitGraphPtr g = hg->relation_graphs()[0];
UnitGraphPtr g_ptr = std::dynamic_pointer_cast<UnitGraph>(g->GetGraphInFormat(SparseFormat::kCSR));
auto out_csr_matrix = g_ptr->GetCSRMatrix(0);
ASSERT_EQ(out_csr_matrix.num_cols, csr.num_rows);
ASSERT_EQ(out_csr_matrix.num_rows, csr.num_cols);
ASSERT_EQ(g->GetFormatInUse(), 4);
out_csr_matrix = g->GetCSRMatrix(0);
ASSERT_EQ(out_csr_matrix.num_cols, csr.num_rows);
ASSERT_EQ(out_csr_matrix.num_rows, csr.num_cols);
ASSERT_EQ(g->GetFormatInUse(), 6);
// test out csr
hg = std::dynamic_pointer_cast<HeteroGraph>(CreateFromCSR(2, csr, SparseFormat::kAny));
g = hg->relation_graphs()[0];
out_csr_matrix = g->GetCSRMatrix(0);
ASSERT_EQ(out_csr_matrix.num_rows, csr.num_rows);
ASSERT_EQ(out_csr_matrix.num_cols, csr.num_cols);
ASSERT_EQ(g->GetFormatInUse(), 2);
// test out coo
hg = std::dynamic_pointer_cast<HeteroGraph>(CreateFromCOO(2, coo, SparseFormat::kAny));
g = hg->relation_graphs()[0];
g_ptr = std::dynamic_pointer_cast<UnitGraph>(g->GetGraphInFormat(SparseFormat::kCSR));
out_csr_matrix = g_ptr->GetCSRMatrix(0);
ASSERT_EQ(out_csr_matrix.num_rows, coo.num_rows);
ASSERT_EQ(out_csr_matrix.num_cols, coo.num_cols);
ASSERT_EQ(g->GetFormatInUse(), 1);
out_csr_matrix = g->GetCSRMatrix(0);
ASSERT_EQ(out_csr_matrix.num_rows, coo.num_rows);
ASSERT_EQ(out_csr_matrix.num_cols, coo.num_cols);
ASSERT_EQ(g->GetFormatInUse(), 3);
}
template <typename IdType>
void _TestUnitGraph_GetCOO(DLContext ctx) {
const aten::CSRMatrix &csr = CSR1<IdType>(ctx);
const aten::COOMatrix &coo = COO1<IdType>(ctx);
auto hg = std::dynamic_pointer_cast<HeteroGraph>(CreateFromCSC(2, csr, SparseFormat::kAny));
UnitGraphPtr g = hg->relation_graphs()[0];
UnitGraphPtr g_ptr = std::dynamic_pointer_cast<UnitGraph>(g->GetGraphInFormat(SparseFormat::kCOO));
auto out_coo_matrix = g_ptr->GetCOOMatrix(0);
ASSERT_EQ(out_coo_matrix.num_cols, csr.num_rows);
ASSERT_EQ(out_coo_matrix.num_rows, csr.num_cols);
ASSERT_EQ(g->GetFormatInUse(), 4);
out_coo_matrix = g->GetCOOMatrix(0);
ASSERT_EQ(out_coo_matrix.num_cols, csr.num_rows);
ASSERT_EQ(out_coo_matrix.num_rows, csr.num_cols);
ASSERT_EQ(g->GetFormatInUse(), 5);
// test out csr
hg = std::dynamic_pointer_cast<HeteroGraph>(CreateFromCSR(2, csr, SparseFormat::kAny));
g = hg->relation_graphs()[0];
g_ptr = std::dynamic_pointer_cast<UnitGraph>(g->GetGraphInFormat(SparseFormat::kCOO));
out_coo_matrix = g_ptr->GetCOOMatrix(0);
ASSERT_EQ(out_coo_matrix.num_rows, csr.num_rows);
ASSERT_EQ(out_coo_matrix.num_cols, csr.num_cols);
ASSERT_EQ(g->GetFormatInUse(), 2);
out_coo_matrix = g->GetCOOMatrix(0);
ASSERT_EQ(out_coo_matrix.num_rows, csr.num_rows);
ASSERT_EQ(out_coo_matrix.num_cols, csr.num_cols);
ASSERT_EQ(g->GetFormatInUse(), 3);
// test out coo
hg = std::dynamic_pointer_cast<HeteroGraph>(CreateFromCOO(2, coo, SparseFormat::kAny));
g = hg->relation_graphs()[0];
out_coo_matrix = g->GetCOOMatrix(0);
ASSERT_EQ(out_coo_matrix.num_rows, coo.num_rows);
ASSERT_EQ(out_coo_matrix.num_cols, coo.num_cols);
ASSERT_EQ(g->GetFormatInUse(), 1);
}
template <typename IdType>
void _TestUnitGraph_Reserve(DLContext ctx) {
const aten::CSRMatrix &csr = CSR1<IdType>(ctx);
const aten::COOMatrix &coo = COO1<IdType>(ctx);
auto hg = std::dynamic_pointer_cast<HeteroGraph>(CreateFromCSC(2, csr, SparseFormat::kAny));
UnitGraphPtr g = hg->relation_graphs()[0];
ASSERT_EQ(g->GetFormatInUse(), 4);
UnitGraphPtr r_g = g->Reverse();
ASSERT_EQ(r_g->GetFormatInUse(), 2);
aten::CSRMatrix g_in_csr = g->GetCSCMatrix(0);
aten::CSRMatrix r_g_out_csr = r_g->GetCSRMatrix(0);
ASSERT_TRUE(g_in_csr.indptr->data == r_g_out_csr.indptr->data);
ASSERT_TRUE(g_in_csr.indices->data == r_g_out_csr.indices->data);
aten::CSRMatrix g_out_csr = g->GetCSRMatrix(0);
ASSERT_EQ(g->GetFormatInUse(), 6);
ASSERT_EQ(r_g->GetFormatInUse(), 6);
aten::CSRMatrix r_g_in_csr = r_g->GetCSCMatrix(0);
ASSERT_TRUE(g_out_csr.indptr->data == r_g_in_csr.indptr->data);
ASSERT_TRUE(g_out_csr.indices->data == r_g_in_csr.indices->data);
aten::COOMatrix g_coo = g->GetCOOMatrix(0);
ASSERT_EQ(g->GetFormatInUse(), 7);
ASSERT_EQ(r_g->GetFormatInUse(), 6);
aten::COOMatrix r_g_coo = r_g->GetCOOMatrix(0);
ASSERT_EQ(r_g->GetFormatInUse(), 7);
ASSERT_EQ(g_coo.num_rows, r_g_coo.num_cols);
ASSERT_EQ(g_coo.num_cols, r_g_coo.num_rows);
ASSERT_TRUE(ArrayEQ<IdType>(g_coo.row, r_g_coo.col));
ASSERT_TRUE(ArrayEQ<IdType>(g_coo.col, r_g_coo.row));
// test out csr
hg = std::dynamic_pointer_cast<HeteroGraph>(CreateFromCSR(2, csr, SparseFormat::kAny));
g = hg->relation_graphs()[0];
ASSERT_EQ(g->GetFormatInUse(), 2);
r_g = g->Reverse();
ASSERT_EQ(r_g->GetFormatInUse(), 4);
g_out_csr = g->GetCSRMatrix(0);
r_g_in_csr = r_g->GetCSCMatrix(0);
ASSERT_TRUE(g_out_csr.indptr->data == r_g_in_csr.indptr->data);
ASSERT_TRUE(g_out_csr.indices->data == r_g_in_csr.indices->data);
g_in_csr = g->GetCSCMatrix(0);
ASSERT_EQ(g->GetFormatInUse(), 6);
ASSERT_EQ(r_g->GetFormatInUse(), 6);
r_g_out_csr = r_g->GetCSRMatrix(0);
ASSERT_TRUE(g_in_csr.indptr->data == r_g_out_csr.indptr->data);
ASSERT_TRUE(g_in_csr.indices->data == r_g_out_csr.indices->data);
g_coo = g->GetCOOMatrix(0);
ASSERT_EQ(g->GetFormatInUse(), 7);
ASSERT_EQ(r_g->GetFormatInUse(), 6);
r_g_coo = r_g->GetCOOMatrix(0);
ASSERT_EQ(r_g->GetFormatInUse(), 7);
ASSERT_EQ(g_coo.num_rows, r_g_coo.num_cols);
ASSERT_EQ(g_coo.num_cols, r_g_coo.num_rows);
ASSERT_TRUE(ArrayEQ<IdType>(g_coo.row, r_g_coo.col));
ASSERT_TRUE(ArrayEQ<IdType>(g_coo.col, r_g_coo.row));
// test out coo
hg = std::dynamic_pointer_cast<HeteroGraph>(CreateFromCOO(2, coo, SparseFormat::kAny));
g = hg->relation_graphs()[0];
ASSERT_EQ(g->GetFormatInUse(), 1);
r_g = g->Reverse();
ASSERT_EQ(r_g->GetFormatInUse(), 1);
g_coo = g->GetCOOMatrix(0);
r_g_coo = r_g->GetCOOMatrix(0);
ASSERT_EQ(g_coo.num_rows, r_g_coo.num_cols);
ASSERT_EQ(g_coo.num_cols, r_g_coo.num_rows);
ASSERT_TRUE(g_coo.row->data == r_g_coo.col->data);
ASSERT_TRUE(g_coo.col->data == r_g_coo.row->data);
g_in_csr = g->GetCSCMatrix(0);
ASSERT_EQ(g->GetFormatInUse(), 5);
ASSERT_EQ(r_g->GetFormatInUse(), 3);
r_g_out_csr = r_g->GetCSRMatrix(0);
ASSERT_TRUE(g_in_csr.indptr->data == r_g_out_csr.indptr->data);
ASSERT_TRUE(g_in_csr.indices->data == r_g_out_csr.indices->data);
g_out_csr = g->GetCSRMatrix(0);
ASSERT_EQ(g->GetFormatInUse(), 7);
ASSERT_EQ(r_g->GetFormatInUse(), 7);
r_g_in_csr = r_g->GetCSCMatrix(0);
ASSERT_TRUE(g_out_csr.indptr->data == r_g_in_csr.indptr->data);
ASSERT_TRUE(g_out_csr.indices->data == r_g_in_csr.indices->data);
}
TEST(UniGraphTest, TestUnitGraph_Create) {
_TestUnitGraph<int32_t>(CPU);
_TestUnitGraph<int64_t>(CPU);
#ifdef DGL_USE_CUDA
_TestUnitGraph<int32_t>(GPU);
_TestUnitGraph<int64_t>(GPU);
#endif
}
TEST(UniGraphTest, TestUnitGraph_GetInCSR) {
_TestUnitGraph_GetInCSR<int32_t>(CPU);
_TestUnitGraph_GetInCSR<int64_t>(CPU);
#ifdef DGL_USE_CUDA
_TestUnitGraph_GetInCSR<int32_t>(GPU);
_TestUnitGraph_GetInCSR<int64_t>(GPU);
#endif
}
TEST(UniGraphTest, TestUnitGraph_GetOutCSR) {
_TestUnitGraph_GetOutCSR<int32_t>(CPU);
_TestUnitGraph_GetOutCSR<int64_t>(CPU);
#ifdef DGL_USE_CUDA
_TestUnitGraph_GetOutCSR<int32_t>(GPU);
_TestUnitGraph_GetOutCSR<int64_t>(GPU);
#endif
}
TEST(UniGraphTest, TestUnitGraph_GetCOO) {
_TestUnitGraph_GetCOO<int32_t>(CPU);
_TestUnitGraph_GetCOO<int64_t>(CPU);
#ifdef DGL_USE_CUDA
_TestUnitGraph_GetCOO<int32_t>(GPU);
_TestUnitGraph_GetCOO<int64_t>(GPU);
#endif
}
TEST(UniGraphTest, TestUnitGraph_Reserve) {
_TestUnitGraph_Reserve<int32_t>(CPU);
_TestUnitGraph_Reserve<int64_t>(CPU);
#ifdef DGL_USE_CUDA
_TestUnitGraph_Reserve<int32_t>(GPU);
_TestUnitGraph_Reserve<int64_t>(GPU);
#endif
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment