Unverified Commit 05c53ca3 authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

[Performance] Prefer parallelized conversion to CSC from COO instead of transposing CSR (#2793)

* fix coo2csr speed

* add comments
parent 86229d42
...@@ -1310,18 +1310,20 @@ UnitGraph::CSRPtr UnitGraph::GetInCSR(bool inplace) const { ...@@ -1310,18 +1310,20 @@ UnitGraph::CSRPtr UnitGraph::GetInCSR(bool inplace) const {
LOG(FATAL) << "The graph have restricted sparse format " << LOG(FATAL) << "The graph have restricted sparse format " <<
CodeToStr(formats_) << ", cannot create CSC matrix."; CodeToStr(formats_) << ", cannot create CSC matrix.";
CSRPtr ret = in_csr_; CSRPtr ret = in_csr_;
// Prefers converting from COO since it is parallelized.
// TODO(BarclayII): need benchmarking.
if (!in_csr_->defined()) { if (!in_csr_->defined()) {
if (out_csr_->defined()) { if (coo_->defined()) {
const auto& newadj = aten::CSRTranspose(out_csr_->adj()); const auto& newadj = aten::COOToCSR(
aten::COOTranspose(coo_->adj()));
if (inplace) if (inplace)
*(const_cast<UnitGraph*>(this)->in_csr_) = CSR(meta_graph(), newadj); *(const_cast<UnitGraph*>(this)->in_csr_) = CSR(meta_graph(), newadj);
else else
ret = std::make_shared<CSR>(meta_graph(), newadj); ret = std::make_shared<CSR>(meta_graph(), newadj);
} else { } else {
CHECK(coo_->defined()) << "None of CSR, COO exist"; CHECK(out_csr_->defined()) << "None of CSR, COO exist";
const auto& newadj = aten::COOToCSR( const auto& newadj = aten::CSRTranspose(out_csr_->adj());
aten::COOTranspose(coo_->adj()));
if (inplace) if (inplace)
*(const_cast<UnitGraph*>(this)->in_csr_) = CSR(meta_graph(), newadj); *(const_cast<UnitGraph*>(this)->in_csr_) = CSR(meta_graph(), newadj);
...@@ -1339,17 +1341,19 @@ UnitGraph::CSRPtr UnitGraph::GetOutCSR(bool inplace) const { ...@@ -1339,17 +1341,19 @@ UnitGraph::CSRPtr UnitGraph::GetOutCSR(bool inplace) const {
LOG(FATAL) << "The graph have restricted sparse format " << LOG(FATAL) << "The graph have restricted sparse format " <<
CodeToStr(formats_) << ", cannot create CSR matrix."; CodeToStr(formats_) << ", cannot create CSR matrix.";
CSRPtr ret = out_csr_; CSRPtr ret = out_csr_;
// Prefers converting from COO since it is parallelized.
// TODO(BarclayII): need benchmarking.
if (!out_csr_->defined()) { if (!out_csr_->defined()) {
if (in_csr_->defined()) { if (coo_->defined()) {
const auto& newadj = aten::CSRTranspose(in_csr_->adj()); const auto& newadj = aten::COOToCSR(coo_->adj());
if (inplace) if (inplace)
*(const_cast<UnitGraph*>(this)->out_csr_) = CSR(meta_graph(), newadj); *(const_cast<UnitGraph*>(this)->out_csr_) = CSR(meta_graph(), newadj);
else else
ret = std::make_shared<CSR>(meta_graph(), newadj); ret = std::make_shared<CSR>(meta_graph(), newadj);
} else { } else {
CHECK(coo_->defined()) << "None of CSR, COO exist"; CHECK(in_csr_->defined()) << "None of CSR, COO exist";
const auto& newadj = aten::COOToCSR(coo_->adj()); const auto& newadj = aten::CSRTranspose(in_csr_->adj());
if (inplace) if (inplace)
*(const_cast<UnitGraph*>(this)->out_csr_) = CSR(meta_graph(), newadj); *(const_cast<UnitGraph*>(this)->out_csr_) = CSR(meta_graph(), newadj);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment