Unverified Commit 223a3da5 authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

[Bugfix] Fix multiple bugs and code refactor (#3841)



* fix

* remove setcxx methods

* move pin flag to CSR and COO matrix
Co-authored-by: default avatarJinjing Zhou <VoVAllen@users.noreply.github.com>
parent 12e97c54
......@@ -47,6 +47,8 @@ struct COOMatrix {
bool row_sorted = false;
/*! \brief whether the column indices per row are sorted */
bool col_sorted = false;
/*! \brief whether the matrix is in pinned memory */
bool is_pinned = false;
/*! \brief default constructor */
COOMatrix() = default;
/*! \brief constructor */
......@@ -139,11 +141,14 @@ struct COOMatrix {
* The context check is deferred to pinning the NDArray.
*/
inline void PinMemory_() {
if (is_pinned)
return;
row.PinMemory_();
col.PinMemory_();
if (!aten::IsNullArray(data)) {
data.PinMemory_();
}
is_pinned = true;
}
/*!
......@@ -154,11 +159,14 @@ struct COOMatrix {
* The context check is deferred to unpinning the NDArray.
*/
inline void UnpinMemory_() {
if (!is_pinned)
return;
row.UnpinMemory_();
col.UnpinMemory_();
if (!aten::IsNullArray(data)) {
data.UnpinMemory_();
}
is_pinned = false;
}
};
......
......@@ -44,6 +44,8 @@ struct CSRMatrix {
IdArray data;
/*! \brief whether the column indices per row are sorted */
bool sorted = false;
/*! \brief whether the matrix is in pinned memory */
bool is_pinned = false;
/*! \brief default constructor */
CSRMatrix() = default;
/*! \brief constructor */
......@@ -132,11 +134,14 @@ struct CSRMatrix {
* The context check is deferred to pinning the NDArray.
*/
inline void PinMemory_() {
if (is_pinned)
return;
indptr.PinMemory_();
indices.PinMemory_();
if (!aten::IsNullArray(data)) {
data.PinMemory_();
}
is_pinned = true;
}
/*!
......@@ -147,11 +152,14 @@ struct CSRMatrix {
* The context check is deferred to unpinning the NDArray.
*/
inline void UnpinMemory_() {
if (!is_pinned)
return;
indptr.UnpinMemory_();
indices.UnpinMemory_();
if (!aten::IsNullArray(data)) {
data.UnpinMemory_();
}
is_pinned = false;
}
};
......
......@@ -437,21 +437,6 @@ class BaseHeteroGraph : public runtime::Object {
*/
virtual aten::CSRMatrix GetCSCMatrix(dgl_type_t etype) const = 0;
/*!
* \brief Set the COO matrix representation for a given edge type.
*/
virtual void SetCOOMatrix(dgl_type_t etype, aten::COOMatrix coo) = 0;
/*!
* \brief Set the CSR matrix representation for a given edge type.
*/
virtual void SetCSRMatrix(dgl_type_t etype, aten::CSRMatrix csr) = 0;
/*!
* \brief Set the CSC matrix representation for a given edge type.
*/
virtual void SetCSCMatrix(dgl_type_t etype, aten::CSRMatrix csc) = 0;
/*!
* \brief Extract the induced subgraph by the given vertices.
*
......
......@@ -30,7 +30,7 @@ def prepare_tensor(g, data, name):
Data in tensor object.
"""
if F.is_tensor(data):
if not g.is_pinned() and (F.dtype(data) != g.idtype or F.context(data) != g.device):
if (F.dtype(data) != g.idtype or F.context(data) != g.device) and not g.is_pinned():
raise DGLError('Expect argument "{}" to have data type {} and device '
'context {}. But got {} and {}.'.format(
name, g.idtype, g.device, F.dtype(data), F.context(data)))
......
......@@ -309,6 +309,8 @@ HeteroGraphPtr HeteroGraph::CopyToSharedMem(
std::vector<HeteroGraphPtr> relgraphs(g->NumEdgeTypes());
for (dgl_type_t etype = 0 ; etype < g->NumEdgeTypes() ; ++etype) {
auto src_dst_type = g->GetEndpointTypes(etype);
int num_vtypes = (src_dst_type.first == src_dst_type.second ? 1 : 2);
aten::COOMatrix coo;
aten::CSRMatrix csr, csc;
std::string prefix = name + "_" + std::to_string(etype);
......@@ -321,7 +323,8 @@ HeteroGraphPtr HeteroGraph::CopyToSharedMem(
if (has_csc) {
csc = shm.CopyToSharedMem(hg->GetCSCMatrix(etype), prefix + "_csc");
}
relgraphs[etype] = UnitGraph::CreateHomographFrom(csc, csr, coo, has_csc, has_csr, has_coo);
relgraphs[etype] = UnitGraph::CreateUnitGraphFrom(
num_vtypes, csc, csr, coo, has_csc, has_csr, has_coo);
}
auto ret = std::shared_ptr<HeteroGraph>(
......@@ -361,6 +364,8 @@ std::tuple<HeteroGraphPtr, std::vector<std::string>, std::vector<std::string>>
std::vector<HeteroGraphPtr> relgraphs(metagraph->NumEdges());
for (dgl_type_t etype = 0 ; etype < metagraph->NumEdges() ; ++etype) {
auto src_dst = metagraph->FindEdge(etype);
int num_vtypes = (src_dst.first == src_dst.second) ? 1 : 2;
aten::COOMatrix coo;
aten::CSRMatrix csr, csc;
std::string prefix = name + "_" + std::to_string(etype);
......@@ -374,7 +379,8 @@ std::tuple<HeteroGraphPtr, std::vector<std::string>, std::vector<std::string>>
shm.CreateFromSharedMem(&csc, prefix + "_csc");
}
relgraphs[etype] = UnitGraph::CreateHomographFrom(csc, csr, coo, has_csc, has_csr, has_coo);
relgraphs[etype] = UnitGraph::CreateUnitGraphFrom(
num_vtypes, csc, csr, coo, has_csc, has_csr, has_coo);
}
auto ret = std::make_shared<HeteroGraph>(metagraph, relgraphs, num_verts_per_type);
......
......@@ -272,18 +272,6 @@ class HeteroGraph : public BaseHeteroGraph {
return relation_graphs_;
}
void SetCOOMatrix(dgl_type_t etype, aten::COOMatrix coo) override {
GetRelationGraph(etype)->SetCOOMatrix(0, coo);
}
void SetCSRMatrix(dgl_type_t etype, aten::CSRMatrix csr) override {
GetRelationGraph(etype)->SetCSRMatrix(0, csr);
}
void SetCSCMatrix(dgl_type_t etype, aten::CSRMatrix csc) override {
GetRelationGraph(etype)->SetCSCMatrix(0, csc);
}
private:
// To create empty class
friend class Serializer;
......
......@@ -635,8 +635,8 @@ HeteroGraphPtr ImmutableGraph::AsHeteroGraph() const {
if (coo_)
coo = GetCOO()->ToCOOMatrix();
auto g = UnitGraph::CreateHomographFrom(
in_csr, out_csr, coo,
auto g = UnitGraph::CreateUnitGraphFrom(
1, in_csr, out_csr, coo,
in_csr_ != nullptr,
out_csr_ != nullptr,
coo_ != nullptr);
......
......@@ -196,7 +196,12 @@ HeteroGraphPtr HeteroForkingUnpickle(const HeteroPickleStates &states) {
dgl_format_code_t created_formats, allowed_formats;
CHECK(strm->Read(&created_formats)) << "Invalid code for created formats";
CHECK(strm->Read(&allowed_formats)) << "Invalid code for allowed formats";
HeteroGraphPtr relgraph = nullptr;
aten::COOMatrix coo;
aten::CSRMatrix csr;
aten::CSRMatrix csc;
bool has_coo = (created_formats & COO_CODE);
bool has_csr = (created_formats & CSR_CODE);
bool has_csc = (created_formats & CSC_CODE);
if (created_formats & COO_CODE) {
CHECK_GE(states.arrays.end() - array_itr, 2);
......@@ -206,11 +211,7 @@ HeteroGraphPtr HeteroForkingUnpickle(const HeteroPickleStates &states) {
bool csorted;
CHECK(strm->Read(&rsorted)) << "Invalid flag 'rsorted'";
CHECK(strm->Read(&csorted)) << "Invalid flag 'csorted'";
auto coo = aten::COOMatrix(num_src, num_dst, row, col, aten::NullArray(), rsorted, csorted);
if (!relgraph)
relgraph = CreateFromCOO(num_vtypes, coo, allowed_formats);
else
relgraph->SetCOOMatrix(0, coo);
coo = aten::COOMatrix(num_src, num_dst, row, col, aten::NullArray(), rsorted, csorted);
}
if (created_formats & CSR_CODE) {
CHECK_GE(states.arrays.end() - array_itr, 3);
......@@ -219,11 +220,7 @@ HeteroGraphPtr HeteroForkingUnpickle(const HeteroPickleStates &states) {
const auto &edge_id = *(array_itr++);
bool sorted;
CHECK(strm->Read(&sorted)) << "Invalid flag 'sorted'";
auto csr = aten::CSRMatrix(num_src, num_dst, indptr, indices, edge_id, sorted);
if (!relgraph)
relgraph = CreateFromCSR(num_vtypes, csr, allowed_formats);
else
relgraph->SetCSRMatrix(0, csr);
csr = aten::CSRMatrix(num_src, num_dst, indptr, indices, edge_id, sorted);
}
if (created_formats & CSC_CODE) {
CHECK_GE(states.arrays.end() - array_itr, 3);
......@@ -232,13 +229,10 @@ HeteroGraphPtr HeteroForkingUnpickle(const HeteroPickleStates &states) {
const auto &edge_id = *(array_itr++);
bool sorted;
CHECK(strm->Read(&sorted)) << "Invalid flag 'sorted'";
auto csc = aten::CSRMatrix(num_dst, num_src, indptr, indices, edge_id, sorted);
if (!relgraph)
relgraph = CreateFromCSC(num_vtypes, csc, allowed_formats);
else
relgraph->SetCSCMatrix(0, csc);
csc = aten::CSRMatrix(num_dst, num_src, indptr, indices, edge_id, sorted);
}
relgraphs[etype] = relgraph;
relgraphs[etype] = UnitGraph::CreateUnitGraphFrom(
num_vtypes, csc, csr, coo, has_csc, has_csr, has_coo, allowed_formats);
}
return CreateHeteroGraph(metagraph, relgraphs, num_nodes_per_type);
}
......
......@@ -134,7 +134,7 @@ class UnitGraph::COO : public BaseHeteroGraph {
}
bool IsPinned() const override {
return adj_.row.IsPinned();
return adj_.is_pinned;
}
uint8_t NumBits() const override {
......@@ -359,18 +359,6 @@ class UnitGraph::COO : public BaseHeteroGraph {
return aten::CSRMatrix();
}
void SetCOOMatrix(dgl_type_t etype, aten::COOMatrix coo) override {
adj_ = coo;
}
void SetCSRMatrix(dgl_type_t etype, aten::CSRMatrix csr) override {
LOG(FATAL) << "Not enabled for COO graph";
}
void SetCSCMatrix(dgl_type_t etype, aten::CSRMatrix csc) override {
LOG(FATAL) << "Not enabled for COO graph";
}
SparseFormat SelectFormat(dgl_type_t etype, dgl_format_code_t preferred_formats) const override {
LOG(FATAL) << "Not enabled for COO graph";
return SparseFormat::kCOO;
......@@ -548,7 +536,7 @@ class UnitGraph::CSR : public BaseHeteroGraph {
}
bool IsPinned() const override {
return adj_.indices.IsPinned();
return adj_.is_pinned;
}
uint8_t NumBits() const override {
......@@ -791,18 +779,6 @@ class UnitGraph::CSR : public BaseHeteroGraph {
return adj_;
}
void SetCOOMatrix(dgl_type_t etype, aten::COOMatrix coo) override {
LOG(FATAL) << "Not enabled for CSR graph";
}
void SetCSRMatrix(dgl_type_t etype, aten::CSRMatrix csr) override {
adj_ = csr;
}
void SetCSCMatrix(dgl_type_t etype, aten::CSRMatrix csc) override {
LOG(FATAL) << "Please use in_csr_->SetCSRMatrix(etype, csc) instead.";
}
SparseFormat SelectFormat(dgl_type_t etype, dgl_format_code_t preferred_formats) const override {
LOG(FATAL) << "Not enabled for CSR graph";
return SparseFormat::kCSR;
......@@ -1370,7 +1346,8 @@ UnitGraph::UnitGraph(GraphPtr metagraph, CSRPtr in_csr, CSRPtr out_csr, COOPtr c
CHECK(GetAny()) << "At least one graph structure should exist.";
}
HeteroGraphPtr UnitGraph::CreateHomographFrom(
HeteroGraphPtr UnitGraph::CreateUnitGraphFrom(
int num_vtypes,
const aten::CSRMatrix &in_csr,
const aten::CSRMatrix &out_csr,
const aten::COOMatrix &coo,
......@@ -1378,7 +1355,7 @@ HeteroGraphPtr UnitGraph::CreateHomographFrom(
bool has_out_csr,
bool has_coo,
dgl_format_code_t formats) {
auto mg = CreateUnitGraphMetaGraph1();
auto mg = CreateUnitGraphMetaGraph(num_vtypes);
CSRPtr in_csr_ptr = nullptr;
CSRPtr out_csr_ptr = nullptr;
......@@ -1512,54 +1489,6 @@ aten::COOMatrix UnitGraph::GetCOOMatrix(dgl_type_t etype) const {
return GetCOO()->adj();
}
void UnitGraph::SetCOOMatrix(dgl_type_t etype, COOMatrix coo) {
if (!(formats_ & COO_CODE)) {
LOG(FATAL) << "The graph have restricted sparse format " <<
CodeToStr(formats_) << ", cannot set COO matrix.";
return;
}
if (IsPinned()) {
LOG(FATAL) << "Cannot set COOMatrix if the graph is pinned, please unpin the graph.";
return;
}
if (!coo_->defined())
*(const_cast<UnitGraph*>(this)->coo_) = COO(meta_graph(), coo);
else
coo_->SetCOOMatrix(0, coo);
}
void UnitGraph::SetCSRMatrix(dgl_type_t etype, CSRMatrix csr) {
if (!(formats_ & CSR_CODE)) {
LOG(FATAL) << "The graph have restricted sparse format " <<
CodeToStr(formats_) << ", cannot set CSR matrix.";
return;
}
if (IsPinned()) {
LOG(FATAL) << "Cannot set CSRMatrix if the graph is pinned, please unpin the graph.";
return;
}
if (!out_csr_->defined())
*(const_cast<UnitGraph*>(this)->out_csr_) = CSR(meta_graph(), csr);
else
out_csr_->SetCSRMatrix(0, csr);
}
void UnitGraph::SetCSCMatrix(dgl_type_t etype, CSRMatrix csc) {
if (!(formats_ & CSC_CODE)) {
LOG(FATAL) << "The graph have restricted sparse format " <<
CodeToStr(formats_) << ", cannot set CSC matrix.";
return;
}
if (IsPinned()) {
LOG(FATAL) << "Cannot set CSCMatrix if the graph is pinned, please unpin the graph.";
return;
}
if (!in_csr_->defined())
*(const_cast<UnitGraph*>(this)->in_csr_) = CSR(meta_graph(), csc);
else
in_csr_->SetCSRMatrix(0, csc);
}
HeteroGraphPtr UnitGraph::GetAny() const {
if (in_csr_->defined()) {
return in_csr_;
......
......@@ -305,14 +305,11 @@ class UnitGraph : public BaseHeteroGraph {
void InvalidateCOO();
void SetCOOMatrix(dgl_type_t etype, aten::COOMatrix coo) override;
void SetCSRMatrix(dgl_type_t etype, aten::CSRMatrix csr) override;
void SetCSCMatrix(dgl_type_t etype, aten::CSRMatrix csc) override;
private:
friend class Serializer;
friend class HeteroGraph;
friend class ImmutableGraph;
friend HeteroGraphPtr HeteroForkingUnpickle(const HeteroPickleStates& states);
// private empty constructor
UnitGraph() {}
......@@ -329,6 +326,7 @@ class UnitGraph : public BaseHeteroGraph {
/*!
* \brief constructor
* \param num_vtypes number of vertex types (1 or 2)
* \param metagraph metagraph
* \param in_csr in edge csr
* \param out_csr out edge csr
......@@ -337,7 +335,8 @@ class UnitGraph : public BaseHeteroGraph {
* \param has_out_csr whether out_csr is valid
* \param has_coo whether coo is valid
*/
static HeteroGraphPtr CreateHomographFrom(
static HeteroGraphPtr CreateUnitGraphFrom(
int num_vtypes,
const aten::CSRMatrix &in_csr,
const aten::CSRMatrix &out_csr,
const aten::COOMatrix &coo,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment