/** * Copyright (c) 2019 by Contributors * @file graph/unit_graph.cc * @brief UnitGraph graph implementation */ #include "./unit_graph.h" #include #include #include #include #include "../c_api_common.h" #include "./serialize/dglstream.h" namespace dgl { namespace { using namespace dgl::aten; // create metagraph of one node type inline GraphPtr CreateUnitGraphMetaGraph1() { // a self-loop edge 0->0 std::vector row_vec(1, 0); std::vector col_vec(1, 0); IdArray row = aten::VecToIdArray(row_vec); IdArray col = aten::VecToIdArray(col_vec); GraphPtr g = ImmutableGraph::CreateFromCOO(1, row, col); return g; } // create metagraph of two node types inline GraphPtr CreateUnitGraphMetaGraph2() { // an edge 0->1 std::vector row_vec(1, 0); std::vector col_vec(1, 1); IdArray row = aten::VecToIdArray(row_vec); IdArray col = aten::VecToIdArray(col_vec); GraphPtr g = ImmutableGraph::CreateFromCOO(2, row, col); return g; } inline GraphPtr CreateUnitGraphMetaGraph(int num_vtypes) { static GraphPtr mg1 = CreateUnitGraphMetaGraph1(); static GraphPtr mg2 = CreateUnitGraphMetaGraph2(); if (num_vtypes == 1) return mg1; else if (num_vtypes == 2) return mg2; else LOG(FATAL) << "Invalid number of vertex types. Must be 1 or 2."; return {}; } }; // namespace ////////////////////////////////////////////////////////// // // COO graph implementation // ////////////////////////////////////////////////////////// class UnitGraph::COO : public BaseHeteroGraph { public: COO(GraphPtr metagraph, int64_t num_src, int64_t num_dst, IdArray src, IdArray dst, bool row_sorted = false, bool col_sorted = false) : BaseHeteroGraph(metagraph) { CHECK(aten::IsValidIdArray(src)); CHECK(aten::IsValidIdArray(dst)); CHECK_EQ(src->shape[0], dst->shape[0]) << "Input arrays should have the same length."; adj_ = aten::COOMatrix{num_src, num_dst, src, dst, NullArray(), row_sorted, col_sorted}; } COO(GraphPtr metagraph, const aten::COOMatrix& coo) : BaseHeteroGraph(metagraph), adj_(coo) { // Data index should not be inherited. Edges in COO format are always // assigned ids from 0 to num_edges - 1. CHECK(!COOHasData(coo)) << "[BUG] COO should not contain data."; adj_.data = aten::NullArray(); } COO() { // set magic num_rows/num_cols to mark it as undefined // adj_.num_rows == 0 and adj_.num_cols == 0 means empty UnitGraph which is // supported adj_.num_rows = -1; adj_.num_cols = -1; }; bool defined() const { return (adj_.num_rows >= 0) && (adj_.num_cols >= 0); } inline dgl_type_t SrcType() const { return 0; } inline dgl_type_t DstType() const { return NumVertexTypes() == 1 ? 0 : 1; } inline dgl_type_t EdgeType() const { return 0; } HeteroGraphPtr GetRelationGraph(dgl_type_t etype) const override { LOG(FATAL) << "The method shouldn't be called for UnitGraph graph. " << "The relation graph is simply this graph itself."; return {}; } void AddVertices(dgl_type_t vtype, uint64_t num_vertices) override { LOG(FATAL) << "UnitGraph graph is not mutable."; } void AddEdge(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) override { LOG(FATAL) << "UnitGraph graph is not mutable."; } void AddEdges(dgl_type_t etype, IdArray src_ids, IdArray dst_ids) override { LOG(FATAL) << "UnitGraph graph is not mutable."; } void Clear() override { LOG(FATAL) << "UnitGraph graph is not mutable."; } DGLDataType DataType() const override { return adj_.row->dtype; } DGLContext Context() const override { return adj_.row->ctx; } bool IsPinned() const override { return adj_.is_pinned; } uint8_t NumBits() const override { return adj_.row->dtype.bits; } COO AsNumBits(uint8_t bits) const { if (NumBits() == bits) return *this; COO ret( meta_graph_, adj_.num_rows, adj_.num_cols, aten::AsNumBits(adj_.row, bits), aten::AsNumBits(adj_.col, bits)); return ret; } COO CopyTo(const DGLContext& ctx) const { if (Context() == ctx) return *this; return COO(meta_graph_, adj_.CopyTo(ctx)); } /** * @brief Copy the adj_ to pinned memory. * @return COOMatrix of the COO graph. */ COO PinMemory() { if (adj_.is_pinned) return *this; return COO(meta_graph_, adj_.PinMemory()); } /** @brief Pin the adj_: COOMatrix of the COO graph. */ void PinMemory_() { adj_.PinMemory_(); } /** @brief Unpin the adj_: COOMatrix of the COO graph. */ void UnpinMemory_() { adj_.UnpinMemory_(); } /** @brief Record stream for the adj_: COOMatrix of the COO graph. */ void RecordStream(DGLStreamHandle stream) override { adj_.RecordStream(stream); } bool IsMultigraph() const override { return aten::COOHasDuplicate(adj_); } bool IsReadonly() const override { return true; } uint64_t NumVertices(dgl_type_t vtype) const override { if (vtype == SrcType()) { return adj_.num_rows; } else if (vtype == DstType()) { return adj_.num_cols; } else { LOG(FATAL) << "Invalid vertex type: " << vtype; return 0; } } uint64_t NumEdges(dgl_type_t etype) const override { return adj_.row->shape[0]; } bool HasVertex(dgl_type_t vtype, dgl_id_t vid) const override { return vid < NumVertices(vtype); } BoolArray HasVertices(dgl_type_t vtype, IdArray vids) const override { LOG(FATAL) << "Not enabled for COO graph"; return {}; } bool HasEdgeBetween( dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const override { CHECK(HasVertex(SrcType(), src)) << "Invalid src vertex id: " << src; CHECK(HasVertex(DstType(), dst)) << "Invalid dst vertex id: " << dst; return aten::COOIsNonZero(adj_, src, dst); } BoolArray HasEdgesBetween( dgl_type_t etype, IdArray src_ids, IdArray dst_ids) const override { CHECK(aten::IsValidIdArray(src_ids)) << "Invalid vertex id array."; CHECK(aten::IsValidIdArray(dst_ids)) << "Invalid vertex id array."; return aten::COOIsNonZero(adj_, src_ids, dst_ids); } IdArray Predecessors(dgl_type_t etype, dgl_id_t dst) const override { CHECK(HasVertex(DstType(), dst)) << "Invalid dst vertex id: " << dst; return aten::COOGetRowDataAndIndices(aten::COOTranspose(adj_), dst).second; } IdArray Successors(dgl_type_t etype, dgl_id_t src) const override { CHECK(HasVertex(SrcType(), src)) << "Invalid src vertex id: " << src; return aten::COOGetRowDataAndIndices(adj_, src).second; } IdArray EdgeId(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const override { CHECK(HasVertex(SrcType(), src)) << "Invalid src vertex id: " << src; CHECK(HasVertex(DstType(), dst)) << "Invalid dst vertex id: " << dst; return aten::COOGetAllData(adj_, src, dst); } EdgeArray EdgeIdsAll( dgl_type_t etype, IdArray src, IdArray dst) const override { CHECK(aten::IsValidIdArray(src)) << "Invalid vertex id array."; CHECK(aten::IsValidIdArray(dst)) << "Invalid vertex id array."; const auto& arrs = aten::COOGetDataAndIndices(adj_, src, dst); return EdgeArray{arrs[0], arrs[1], arrs[2]}; } IdArray EdgeIdsOne( dgl_type_t etype, IdArray src, IdArray dst) const override { return aten::COOGetData(adj_, src, dst); } std::pair FindEdge( dgl_type_t etype, dgl_id_t eid) const override { CHECK(eid < NumEdges(etype)) << "Invalid edge id: " << eid; const dgl_id_t src = aten::IndexSelect(adj_.row, eid); const dgl_id_t dst = aten::IndexSelect(adj_.col, eid); return std::pair(src, dst); } EdgeArray FindEdges(dgl_type_t etype, IdArray eids) const override { CHECK(aten::IsValidIdArray(eids)) << "Invalid edge id array"; BUG_IF_FAIL(aten::IsNullArray(adj_.data)) << "FindEdges requires the internal COO matrix not having EIDs."; return EdgeArray{ aten::IndexSelect(adj_.row, eids), aten::IndexSelect(adj_.col, eids), eids}; } EdgeArray InEdges(dgl_type_t etype, dgl_id_t vid) const override { IdArray ret_src, ret_eid; std::tie(ret_eid, ret_src) = aten::COOGetRowDataAndIndices(aten::COOTranspose(adj_), vid); IdArray ret_dst = aten::Full(vid, ret_src->shape[0], NumBits(), ret_src->ctx); return EdgeArray{ret_src, ret_dst, ret_eid}; } EdgeArray InEdges(dgl_type_t etype, IdArray vids) const override { CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array."; auto coosubmat = aten::COOSliceRows(aten::COOTranspose(adj_), vids); auto row = aten::IndexSelect(vids, coosubmat.row); return EdgeArray{coosubmat.col, row, coosubmat.data}; } EdgeArray OutEdges(dgl_type_t etype, dgl_id_t vid) const override { IdArray ret_dst, ret_eid; std::tie(ret_eid, ret_dst) = aten::COOGetRowDataAndIndices(adj_, vid); IdArray ret_src = aten::Full(vid, ret_dst->shape[0], NumBits(), ret_dst->ctx); return EdgeArray{ret_src, ret_dst, ret_eid}; } EdgeArray OutEdges(dgl_type_t etype, IdArray vids) const override { CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array."; auto coosubmat = aten::COOSliceRows(adj_, vids); auto row = aten::IndexSelect(vids, coosubmat.row); return EdgeArray{row, coosubmat.col, coosubmat.data}; } EdgeArray Edges( dgl_type_t etype, const std::string& order = "") const override { CHECK(order.empty() || order == std::string("eid")) << "COO only support Edges of order \"eid\", but got \"" << order << "\"."; IdArray rst_eid = aten::Range(0, NumEdges(etype), NumBits(), Context()); return EdgeArray{adj_.row, adj_.col, rst_eid}; } uint64_t InDegree(dgl_type_t etype, dgl_id_t vid) const override { CHECK(HasVertex(DstType(), vid)) << "Invalid dst vertex id: " << vid; return aten::COOGetRowNNZ(aten::COOTranspose(adj_), vid); } DegreeArray InDegrees(dgl_type_t etype, IdArray vids) const override { CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array."; return aten::COOGetRowNNZ(aten::COOTranspose(adj_), vids); } uint64_t OutDegree(dgl_type_t etype, dgl_id_t vid) const override { CHECK(HasVertex(SrcType(), vid)) << "Invalid src vertex id: " << vid; return aten::COOGetRowNNZ(adj_, vid); } DegreeArray OutDegrees(dgl_type_t etype, IdArray vids) const override { CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array."; return aten::COOGetRowNNZ(adj_, vids); } DGLIdIters SuccVec(dgl_type_t etype, dgl_id_t vid) const override { LOG(INFO) << "Not enabled for COO graph."; return {}; } DGLIdIters OutEdgeVec(dgl_type_t etype, dgl_id_t vid) const override { LOG(INFO) << "Not enabled for COO graph."; return {}; } DGLIdIters PredVec(dgl_type_t etype, dgl_id_t vid) const override { LOG(INFO) << "Not enabled for COO graph."; return {}; } DGLIdIters InEdgeVec(dgl_type_t etype, dgl_id_t vid) const override { LOG(INFO) << "Not enabled for COO graph."; return {}; } std::vector GetAdj( dgl_type_t etype, bool transpose, const std::string& fmt) const override { CHECK(fmt == "coo") << "Not valid adj format request."; if (transpose) { return {aten::HStack(adj_.col, adj_.row)}; } else { return {aten::HStack(adj_.row, adj_.col)}; } } aten::COOMatrix GetCOOMatrix(dgl_type_t etype) const override { return adj_; } aten::CSRMatrix GetCSCMatrix(dgl_type_t etype) const override { LOG(FATAL) << "Not enabled for COO graph"; return aten::CSRMatrix(); } aten::CSRMatrix GetCSRMatrix(dgl_type_t etype) const override { LOG(FATAL) << "Not enabled for COO graph"; return aten::CSRMatrix(); } SparseFormat SelectFormat( dgl_type_t etype, dgl_format_code_t preferred_formats) const override { LOG(FATAL) << "Not enabled for COO graph"; return SparseFormat::kCOO; } dgl_format_code_t GetAllowedFormats() const override { LOG(FATAL) << "Not enabled for COO graph"; return 0; } dgl_format_code_t GetCreatedFormats() const override { LOG(FATAL) << "Not enabled for COO graph"; return 0; } HeteroSubgraph VertexSubgraph( const std::vector& vids) const override { CHECK_EQ(vids.size(), NumVertexTypes()) << "Number of vertex types mismatch"; auto srcvids = vids[SrcType()], dstvids = vids[DstType()]; CHECK(aten::IsValidIdArray(srcvids)) << "Invalid vertex id array."; CHECK(aten::IsValidIdArray(dstvids)) << "Invalid vertex id array."; HeteroSubgraph subg; const auto& submat = aten::COOSliceMatrix(adj_, srcvids, dstvids); DGLContext ctx = aten::GetContextOf(vids); IdArray sub_eids = aten::Range(0, submat.data->shape[0], NumBits(), ctx); subg.graph = std::make_shared( meta_graph(), submat.num_rows, submat.num_cols, submat.row, submat.col); subg.induced_vertices = vids; subg.induced_edges.emplace_back(submat.data); return subg; } HeteroSubgraph EdgeSubgraph( const std::vector& eids, bool preserve_nodes = false) const override { CHECK_EQ(eids.size(), 1) << "Edge type number mismatch."; HeteroSubgraph subg; if (!preserve_nodes) { IdArray new_src = aten::IndexSelect(adj_.row, eids[0]); IdArray new_dst = aten::IndexSelect(adj_.col, eids[0]); subg.induced_vertices.emplace_back(aten::Relabel_({new_src})); subg.induced_vertices.emplace_back(aten::Relabel_({new_dst})); const auto new_nsrc = subg.induced_vertices[0]->shape[0]; const auto new_ndst = subg.induced_vertices[1]->shape[0]; subg.graph = std::make_shared( meta_graph(), new_nsrc, new_ndst, new_src, new_dst); subg.induced_edges = eids; } else { IdArray new_src = aten::IndexSelect(adj_.row, eids[0]); IdArray new_dst = aten::IndexSelect(adj_.col, eids[0]); subg.induced_vertices.emplace_back( aten::NullArray(DGLDataType{kDGLInt, NumBits(), 1}, Context())); subg.induced_vertices.emplace_back( aten::NullArray(DGLDataType{kDGLInt, NumBits(), 1}, Context())); subg.graph = std::make_shared( meta_graph(), NumVertices(SrcType()), NumVertices(DstType()), new_src, new_dst); subg.induced_edges = eids; } return subg; } HeteroGraphPtr GetGraphInFormat(dgl_format_code_t formats) const override { LOG(FATAL) << "Not enabled for COO graph."; return nullptr; } aten::COOMatrix adj() const { return adj_; } /** * @brief Determines whether the graph is "hypersparse", i.e. having * significantly more nodes than edges. */ bool IsHypersparse() const { return (NumVertices(SrcType()) / 8 > NumEdges(EdgeType())) && (NumVertices(SrcType()) > 1000000); } bool Load(dmlc::Stream* fs) { auto meta_imgraph = Serializer::make_shared(); CHECK(fs->Read(&meta_imgraph)) << "Invalid meta graph"; meta_graph_ = meta_imgraph; CHECK(fs->Read(&adj_)) << "Invalid adj matrix"; return true; } void Save(dmlc::Stream* fs) const { auto meta_graph_ptr = ImmutableGraph::ToImmutable(meta_graph()); fs->Write(meta_graph_ptr); fs->Write(adj_); } private: friend class Serializer; /** @brief internal adjacency matrix. Data array is empty */ aten::COOMatrix adj_; }; ////////////////////////////////////////////////////////// // // CSR graph implementation // ////////////////////////////////////////////////////////// /** @brief CSR graph */ class UnitGraph::CSR : public BaseHeteroGraph { public: CSR(GraphPtr metagraph, int64_t num_src, int64_t num_dst, IdArray indptr, IdArray indices, IdArray edge_ids) : BaseHeteroGraph(metagraph) { CHECK(aten::IsValidIdArray(indptr)); CHECK(aten::IsValidIdArray(indices)); if (aten::IsValidIdArray(edge_ids)) CHECK( (indices->shape[0] == edge_ids->shape[0]) || aten::IsNullArray(edge_ids)) << "edge id arrays should have the same length as indices if not " "empty"; CHECK_EQ(num_src, indptr->shape[0] - 1) << "number of nodes do not match the length of indptr minus 1."; adj_ = aten::CSRMatrix{num_src, num_dst, indptr, indices, edge_ids}; } CSR(GraphPtr metagraph, const aten::CSRMatrix& csr) : BaseHeteroGraph(metagraph), adj_(csr) {} CSR() { // set magic num_rows/num_cols to mark it as undefined // adj_.num_rows == 0 and adj_.num_cols == 0 means empty UnitGraph which is // supported adj_.num_rows = -1; adj_.num_cols = -1; }; bool defined() const { return (adj_.num_rows >= 0) || (adj_.num_cols >= 0); } inline dgl_type_t SrcType() const { return 0; } inline dgl_type_t DstType() const { return NumVertexTypes() == 1 ? 0 : 1; } inline dgl_type_t EdgeType() const { return 0; } HeteroGraphPtr GetRelationGraph(dgl_type_t etype) const override { LOG(FATAL) << "The method shouldn't be called for UnitGraph graph. " << "The relation graph is simply this graph itself."; return {}; } void AddVertices(dgl_type_t vtype, uint64_t num_vertices) override { LOG(FATAL) << "UnitGraph graph is not mutable."; } void AddEdge(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) override { LOG(FATAL) << "UnitGraph graph is not mutable."; } void AddEdges(dgl_type_t etype, IdArray src_ids, IdArray dst_ids) override { LOG(FATAL) << "UnitGraph graph is not mutable."; } void Clear() override { LOG(FATAL) << "UnitGraph graph is not mutable."; } DGLDataType DataType() const override { return adj_.indices->dtype; } DGLContext Context() const override { return adj_.indices->ctx; } bool IsPinned() const override { return adj_.is_pinned; } uint8_t NumBits() const override { return adj_.indices->dtype.bits; } CSR AsNumBits(uint8_t bits) const { if (NumBits() == bits) { return *this; } else { CSR ret( meta_graph_, adj_.num_rows, adj_.num_cols, aten::AsNumBits(adj_.indptr, bits), aten::AsNumBits(adj_.indices, bits), aten::AsNumBits(adj_.data, bits)); return ret; } } CSR CopyTo(const DGLContext& ctx) const { if (Context() == ctx) { return *this; } else { return CSR(meta_graph_, adj_.CopyTo(ctx)); } } /** * @brief Copy the adj_ to pinned memory. * @return CSRMatrix of the CSR graph. */ CSR PinMemory() { if (adj_.is_pinned) return *this; return CSR(meta_graph_, adj_.PinMemory()); } /** @brief Pin the adj_: CSRMatrix of the CSR graph. */ void PinMemory_() { adj_.PinMemory_(); } /** @brief Unpin the adj_: CSRMatrix of the CSR graph. */ void UnpinMemory_() { adj_.UnpinMemory_(); } /** @brief Record stream for the adj_: CSRMatrix of the CSR graph. */ void RecordStream(DGLStreamHandle stream) override { adj_.RecordStream(stream); } bool IsMultigraph() const override { return aten::CSRHasDuplicate(adj_); } bool IsReadonly() const override { return true; } uint64_t NumVertices(dgl_type_t vtype) const override { if (vtype == SrcType()) { return adj_.num_rows; } else if (vtype == DstType()) { return adj_.num_cols; } else { LOG(FATAL) << "Invalid vertex type: " << vtype; return 0; } } uint64_t NumEdges(dgl_type_t etype) const override { return adj_.indices->shape[0]; } bool HasVertex(dgl_type_t vtype, dgl_id_t vid) const override { return vid < NumVertices(vtype); } BoolArray HasVertices(dgl_type_t vtype, IdArray vids) const override { LOG(FATAL) << "Not enabled for COO graph"; return {}; } bool HasEdgeBetween( dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const override { CHECK(HasVertex(SrcType(), src)) << "Invalid src vertex id: " << src; CHECK(HasVertex(DstType(), dst)) << "Invalid dst vertex id: " << dst; return aten::CSRIsNonZero(adj_, src, dst); } BoolArray HasEdgesBetween( dgl_type_t etype, IdArray src_ids, IdArray dst_ids) const override { CHECK(aten::IsValidIdArray(src_ids)) << "Invalid vertex id array."; CHECK(aten::IsValidIdArray(dst_ids)) << "Invalid vertex id array."; return aten::CSRIsNonZero(adj_, src_ids, dst_ids); } IdArray Predecessors(dgl_type_t etype, dgl_id_t dst) const override { LOG(INFO) << "Not enabled for CSR graph."; return {}; } IdArray Successors(dgl_type_t etype, dgl_id_t src) const override { CHECK(HasVertex(SrcType(), src)) << "Invalid src vertex id: " << src; return aten::CSRGetRowColumnIndices(adj_, src); } IdArray EdgeId(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const override { CHECK(HasVertex(SrcType(), src)) << "Invalid src vertex id: " << src; CHECK(HasVertex(DstType(), dst)) << "Invalid dst vertex id: " << dst; return aten::CSRGetAllData(adj_, src, dst); } EdgeArray EdgeIdsAll( dgl_type_t etype, IdArray src, IdArray dst) const override { CHECK(aten::IsValidIdArray(src)) << "Invalid vertex id array."; CHECK(aten::IsValidIdArray(dst)) << "Invalid vertex id array."; const auto& arrs = aten::CSRGetDataAndIndices(adj_, src, dst); return EdgeArray{arrs[0], arrs[1], arrs[2]}; } IdArray EdgeIdsOne( dgl_type_t etype, IdArray src, IdArray dst) const override { return aten::CSRGetData(adj_, src, dst); } std::pair FindEdge( dgl_type_t etype, dgl_id_t eid) const override { LOG(FATAL) << "Not enabled for CSR graph."; return {}; } EdgeArray FindEdges(dgl_type_t etype, IdArray eids) const override { LOG(FATAL) << "Not enabled for CSR graph."; return {}; } EdgeArray InEdges(dgl_type_t etype, dgl_id_t vid) const override { LOG(FATAL) << "Not enabled for CSR graph."; return {}; } EdgeArray InEdges(dgl_type_t etype, IdArray vids) const override { LOG(FATAL) << "Not enabled for CSR graph."; return {}; } EdgeArray OutEdges(dgl_type_t etype, dgl_id_t vid) const override { CHECK(HasVertex(SrcType(), vid)) << "Invalid src vertex id: " << vid; IdArray ret_dst = aten::CSRGetRowColumnIndices(adj_, vid); IdArray ret_eid = aten::CSRGetRowData(adj_, vid); IdArray ret_src = aten::Full(vid, ret_dst->shape[0], NumBits(), ret_dst->ctx); return EdgeArray{ret_src, ret_dst, ret_eid}; } EdgeArray OutEdges(dgl_type_t etype, IdArray vids) const override { CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array."; auto csrsubmat = aten::CSRSliceRows(adj_, vids); auto coosubmat = aten::CSRToCOO(csrsubmat, false); // Note that the row id in the csr submat is relabled, so // we need to recover it using an index select. auto row = aten::IndexSelect(vids, coosubmat.row); return EdgeArray{row, coosubmat.col, coosubmat.data}; } EdgeArray Edges( dgl_type_t etype, const std::string& order = "") const override { CHECK(order.empty() || order == std::string("srcdst")) << "CSR only support Edges of order \"srcdst\"," << " but got \"" << order << "\"."; auto coo = aten::CSRToCOO(adj_, false); if (order == std::string("srcdst")) { // make sure the coo is sorted if an order is requested coo = aten::COOSort(coo, true); } return EdgeArray{coo.row, coo.col, coo.data}; } uint64_t InDegree(dgl_type_t etype, dgl_id_t vid) const override { LOG(FATAL) << "Not enabled for CSR graph."; return {}; } DegreeArray InDegrees(dgl_type_t etype, IdArray vids) const override { LOG(FATAL) << "Not enabled for CSR graph."; return {}; } uint64_t OutDegree(dgl_type_t etype, dgl_id_t vid) const override { CHECK(HasVertex(SrcType(), vid)) << "Invalid src vertex id: " << vid; return aten::CSRGetRowNNZ(adj_, vid); } DegreeArray OutDegrees(dgl_type_t etype, IdArray vids) const override { CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array."; return aten::CSRGetRowNNZ(adj_, vids); } DGLIdIters SuccVec(dgl_type_t etype, dgl_id_t vid) const override { // TODO(minjie): This still assumes the data type and device context // of this graph. Should fix later. CHECK_EQ(NumBits(), 64); const dgl_id_t* indptr_data = static_cast(adj_.indptr->data); const dgl_id_t* indices_data = static_cast(adj_.indices->data); const dgl_id_t start = indptr_data[vid]; const dgl_id_t end = indptr_data[vid + 1]; return DGLIdIters(indices_data + start, indices_data + end); } DGLIdIters32 SuccVec32(dgl_type_t etype, dgl_id_t vid) { // TODO(minjie): This still assumes the data type and device context // of this graph. Should fix later. const int32_t* indptr_data = static_cast(adj_.indptr->data); const int32_t* indices_data = static_cast(adj_.indices->data); const int32_t start = indptr_data[vid]; const int32_t end = indptr_data[vid + 1]; return DGLIdIters32(indices_data + start, indices_data + end); } DGLIdIters OutEdgeVec(dgl_type_t etype, dgl_id_t vid) const override { // TODO(minjie): This still assumes the data type and device context // of this graph. Should fix later. CHECK_EQ(NumBits(), 64); const dgl_id_t* indptr_data = static_cast(adj_.indptr->data); const dgl_id_t* eid_data = static_cast(adj_.data->data); const dgl_id_t start = indptr_data[vid]; const dgl_id_t end = indptr_data[vid + 1]; return DGLIdIters(eid_data + start, eid_data + end); } DGLIdIters PredVec(dgl_type_t etype, dgl_id_t vid) const override { LOG(FATAL) << "Not enabled for CSR graph."; return {}; } DGLIdIters InEdgeVec(dgl_type_t etype, dgl_id_t vid) const override { LOG(FATAL) << "Not enabled for CSR graph."; return {}; } std::vector GetAdj( dgl_type_t etype, bool transpose, const std::string& fmt) const override { CHECK(!transpose && fmt == "csr") << "Not valid adj format request."; return {adj_.indptr, adj_.indices, adj_.data}; } aten::COOMatrix GetCOOMatrix(dgl_type_t etype) const override { LOG(FATAL) << "Not enabled for CSR graph"; return aten::COOMatrix(); } aten::CSRMatrix GetCSCMatrix(dgl_type_t etype) const override { LOG(FATAL) << "Not enabled for CSR graph"; return aten::CSRMatrix(); } aten::CSRMatrix GetCSRMatrix(dgl_type_t etype) const override { return adj_; } SparseFormat SelectFormat( dgl_type_t etype, dgl_format_code_t preferred_formats) const override { LOG(FATAL) << "Not enabled for CSR graph"; return SparseFormat::kCSR; } dgl_format_code_t GetAllowedFormats() const override { LOG(FATAL) << "Not enabled for COO graph"; return 0; } dgl_format_code_t GetCreatedFormats() const override { LOG(FATAL) << "Not enabled for CSR graph"; return 0; } HeteroSubgraph VertexSubgraph( const std::vector& vids) const override { CHECK_EQ(vids.size(), NumVertexTypes()) << "Number of vertex types mismatch"; auto srcvids = vids[SrcType()], dstvids = vids[DstType()]; CHECK(aten::IsValidIdArray(srcvids)) << "Invalid vertex id array."; CHECK(aten::IsValidIdArray(dstvids)) << "Invalid vertex id array."; HeteroSubgraph subg; const auto& submat = aten::CSRSliceMatrix(adj_, srcvids, dstvids); DGLContext ctx = aten::GetContextOf(vids); IdArray sub_eids = aten::Range(0, submat.data->shape[0], NumBits(), ctx); subg.graph = std::make_shared( meta_graph(), submat.num_rows, submat.num_cols, submat.indptr, submat.indices, sub_eids); subg.induced_vertices = vids; subg.induced_edges.emplace_back(submat.data); return subg; } HeteroSubgraph EdgeSubgraph( const std::vector& eids, bool preserve_nodes = false) const override { LOG(FATAL) << "Not enabled for CSR graph."; return {}; } HeteroGraphPtr GetGraphInFormat(dgl_format_code_t formats) const override { LOG(FATAL) << "Not enabled for CSR graph."; return nullptr; } aten::CSRMatrix adj() const { return adj_; } bool Load(dmlc::Stream* fs) { auto meta_imgraph = Serializer::make_shared(); CHECK(fs->Read(&meta_imgraph)) << "Invalid meta graph"; meta_graph_ = meta_imgraph; CHECK(fs->Read(&adj_)) << "Invalid adj matrix"; return true; } void Save(dmlc::Stream* fs) const { auto meta_graph_ptr = ImmutableGraph::ToImmutable(meta_graph()); fs->Write(meta_graph_ptr); fs->Write(adj_); } private: friend class Serializer; /** @brief internal adjacency matrix. Data array stores edge ids */ aten::CSRMatrix adj_; }; ////////////////////////////////////////////////////////// // // unit graph implementation // ////////////////////////////////////////////////////////// DGLDataType UnitGraph::DataType() const { return GetAny()->DataType(); } DGLContext UnitGraph::Context() const { return GetAny()->Context(); } bool UnitGraph::IsPinned() const { return GetAny()->IsPinned(); } uint8_t UnitGraph::NumBits() const { return GetAny()->NumBits(); } bool UnitGraph::IsMultigraph() const { const SparseFormat fmt = SelectFormat(CSC_CODE); const auto ptr = GetFormat(fmt); return ptr->IsMultigraph(); } uint64_t UnitGraph::NumVertices(dgl_type_t vtype) const { const SparseFormat fmt = SelectFormat(ALL_CODE); const auto ptr = GetFormat(fmt); // TODO(BarclayII): we have a lot of special handling for CSC. // Need to have a UnitGraph::CSC backend instead. if (fmt == SparseFormat::kCSC) vtype = (vtype == SrcType()) ? DstType() : SrcType(); return ptr->NumVertices(vtype); } uint64_t UnitGraph::NumEdges(dgl_type_t etype) const { return GetAny()->NumEdges(etype); } bool UnitGraph::HasVertex(dgl_type_t vtype, dgl_id_t vid) const { const SparseFormat fmt = SelectFormat(ALL_CODE); const auto ptr = GetFormat(fmt); if (fmt == SparseFormat::kCSC) vtype = (vtype == SrcType()) ? DstType() : SrcType(); return ptr->HasVertex(vtype, vid); } BoolArray UnitGraph::HasVertices(dgl_type_t vtype, IdArray vids) const { CHECK(aten::IsValidIdArray(vids)) << "Invalid id array input"; return aten::LT(vids, NumVertices(vtype)); } bool UnitGraph::HasEdgeBetween( dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const { const SparseFormat fmt = SelectFormat(CSC_CODE); const auto ptr = GetFormat(fmt); if (fmt == SparseFormat::kCSC) return ptr->HasEdgeBetween(etype, dst, src); else return ptr->HasEdgeBetween(etype, src, dst); } BoolArray UnitGraph::HasEdgesBetween( dgl_type_t etype, IdArray src, IdArray dst) const { const SparseFormat fmt = SelectFormat(CSC_CODE); const auto ptr = GetFormat(fmt); if (fmt == SparseFormat::kCSC) return ptr->HasEdgesBetween(etype, dst, src); else return ptr->HasEdgesBetween(etype, src, dst); } IdArray UnitGraph::Predecessors(dgl_type_t etype, dgl_id_t dst) const { const SparseFormat fmt = SelectFormat(CSC_CODE); const auto ptr = GetFormat(fmt); if (fmt == SparseFormat::kCSC) return ptr->Successors(etype, dst); else return ptr->Predecessors(etype, dst); } IdArray UnitGraph::Successors(dgl_type_t etype, dgl_id_t src) const { const SparseFormat fmt = SelectFormat(CSR_CODE); const auto ptr = GetFormat(fmt); return ptr->Successors(etype, src); } IdArray UnitGraph::EdgeId(dgl_type_t etype, dgl_id_t src, dgl_id_t dst) const { const SparseFormat fmt = SelectFormat(CSR_CODE); const auto ptr = GetFormat(fmt); if (fmt == SparseFormat::kCSC) return ptr->EdgeId(etype, dst, src); else return ptr->EdgeId(etype, src, dst); } EdgeArray UnitGraph::EdgeIdsAll( dgl_type_t etype, IdArray src, IdArray dst) const { const SparseFormat fmt = SelectFormat(CSR_CODE); const auto ptr = GetFormat(fmt); if (fmt == SparseFormat::kCSC) { EdgeArray edges = ptr->EdgeIdsAll(etype, dst, src); return EdgeArray{edges.dst, edges.src, edges.id}; } else { return ptr->EdgeIdsAll(etype, src, dst); } } IdArray UnitGraph::EdgeIdsOne( dgl_type_t etype, IdArray src, IdArray dst) const { const SparseFormat fmt = SelectFormat(CSR_CODE); const auto ptr = GetFormat(fmt); if (fmt == SparseFormat::kCSC) { return ptr->EdgeIdsOne(etype, dst, src); } else { return ptr->EdgeIdsOne(etype, src, dst); } } std::pair UnitGraph::FindEdge( dgl_type_t etype, dgl_id_t eid) const { const SparseFormat fmt = SelectFormat(COO_CODE); const auto ptr = GetFormat(fmt); return ptr->FindEdge(etype, eid); } EdgeArray UnitGraph::FindEdges(dgl_type_t etype, IdArray eids) const { const SparseFormat fmt = SelectFormat(COO_CODE); const auto ptr = GetFormat(fmt); return ptr->FindEdges(etype, eids); } EdgeArray UnitGraph::InEdges(dgl_type_t etype, dgl_id_t vid) const { const SparseFormat fmt = SelectFormat(CSC_CODE); const auto ptr = GetFormat(fmt); if (fmt == SparseFormat::kCSC) { const EdgeArray& ret = ptr->OutEdges(etype, vid); return {ret.dst, ret.src, ret.id}; } else { return ptr->InEdges(etype, vid); } } EdgeArray UnitGraph::InEdges(dgl_type_t etype, IdArray vids) const { const SparseFormat fmt = SelectFormat(CSC_CODE); const auto ptr = GetFormat(fmt); if (fmt == SparseFormat::kCSC) { const EdgeArray& ret = ptr->OutEdges(etype, vids); return {ret.dst, ret.src, ret.id}; } else { return ptr->InEdges(etype, vids); } } EdgeArray UnitGraph::OutEdges(dgl_type_t etype, dgl_id_t vid) const { const SparseFormat fmt = SelectFormat(CSR_CODE); const auto ptr = GetFormat(fmt); return ptr->OutEdges(etype, vid); } EdgeArray UnitGraph::OutEdges(dgl_type_t etype, IdArray vids) const { const SparseFormat fmt = SelectFormat(CSR_CODE); const auto ptr = GetFormat(fmt); return ptr->OutEdges(etype, vids); } EdgeArray UnitGraph::Edges(dgl_type_t etype, const std::string& order) const { SparseFormat fmt; if (order == std::string("eid")) { fmt = SelectFormat(COO_CODE); } else if (order.empty()) { // arbitrary order fmt = SelectFormat(ALL_CODE); } else if (order == std::string("srcdst")) { fmt = SelectFormat(CSR_CODE); } else { LOG(FATAL) << "Unsupported order request: " << order; return {}; } const auto& edges = GetFormat(fmt)->Edges(etype, order); if (fmt == SparseFormat::kCSC) return EdgeArray{edges.dst, edges.src, edges.id}; else return edges; } uint64_t UnitGraph::InDegree(dgl_type_t etype, dgl_id_t vid) const { SparseFormat fmt = SelectFormat(CSC_CODE); const auto ptr = GetFormat(fmt); CHECK(fmt == SparseFormat::kCSC || fmt == SparseFormat::kCOO) << "In degree cannot be computed as neither CSC nor COO format is " "allowed for this graph. Please enable one of them at least."; return fmt == SparseFormat::kCSC ? ptr->OutDegree(etype, vid) : ptr->InDegree(etype, vid); } DegreeArray UnitGraph::InDegrees(dgl_type_t etype, IdArray vids) const { SparseFormat fmt = SelectFormat(CSC_CODE); const auto ptr = GetFormat(fmt); CHECK(fmt == SparseFormat::kCSC || fmt == SparseFormat::kCOO) << "In degree cannot be computed as neither CSC nor COO format is " "allowed for this graph. Please enable one of them at least."; return fmt == SparseFormat::kCSC ? ptr->OutDegrees(etype, vids) : ptr->InDegrees(etype, vids); } uint64_t UnitGraph::OutDegree(dgl_type_t etype, dgl_id_t vid) const { SparseFormat fmt = SelectFormat(CSR_CODE); const auto ptr = GetFormat(fmt); CHECK(fmt == SparseFormat::kCSR || fmt == SparseFormat::kCOO) << "Out degree cannot be computed as neither CSR nor COO format is " "allowed for this graph. Please enable one of them at least."; return ptr->OutDegree(etype, vid); } DegreeArray UnitGraph::OutDegrees(dgl_type_t etype, IdArray vids) const { SparseFormat fmt = SelectFormat(CSR_CODE); const auto ptr = GetFormat(fmt); CHECK(fmt == SparseFormat::kCSR || fmt == SparseFormat::kCOO) << "Out degree cannot be computed as neither CSR nor COO format is " "allowed for this graph. Please enable one of them at least."; return ptr->OutDegrees(etype, vids); } DGLIdIters UnitGraph::SuccVec(dgl_type_t etype, dgl_id_t vid) const { SparseFormat fmt = SelectFormat(CSR_CODE); const auto ptr = GetFormat(fmt); return ptr->SuccVec(etype, vid); } DGLIdIters32 UnitGraph::SuccVec32(dgl_type_t etype, dgl_id_t vid) const { SparseFormat fmt = SelectFormat(CSR_CODE); const auto ptr = std::dynamic_pointer_cast(GetFormat(fmt)); CHECK_NOTNULL(ptr); return ptr->SuccVec32(etype, vid); } DGLIdIters UnitGraph::OutEdgeVec(dgl_type_t etype, dgl_id_t vid) const { SparseFormat fmt = SelectFormat(CSR_CODE); const auto ptr = GetFormat(fmt); return ptr->OutEdgeVec(etype, vid); } DGLIdIters UnitGraph::PredVec(dgl_type_t etype, dgl_id_t vid) const { SparseFormat fmt = SelectFormat(CSC_CODE); const auto ptr = GetFormat(fmt); if (fmt == SparseFormat::kCSC) return ptr->SuccVec(etype, vid); else return ptr->PredVec(etype, vid); } DGLIdIters UnitGraph::InEdgeVec(dgl_type_t etype, dgl_id_t vid) const { SparseFormat fmt = SelectFormat(CSC_CODE); const auto ptr = GetFormat(fmt); if (fmt == SparseFormat::kCSC) return ptr->OutEdgeVec(etype, vid); else return ptr->InEdgeVec(etype, vid); } std::vector UnitGraph::GetAdj( dgl_type_t etype, bool transpose, const std::string& fmt) const { // TODO(minjie): Our current semantics of adjacency matrix is row for dst // nodes and col for src nodes. Therefore, we need to flip the transpose flag. // For example, // transpose=False is equal to in edge CSR. We have this behavior because // previously we use framework's SPMM and we don't cache reverse adj. This // is not intuitive and also not consistent with networkx's // to_scipy_sparse_matrix. With the upcoming custom kernel change, we should // change the behavior and make row for src and col for dst. if (fmt == std::string("csr")) { return !transpose ? GetOutCSR()->GetAdj(etype, false, "csr") : GetInCSR()->GetAdj(etype, false, "csr"); } else if (fmt == std::string("coo")) { return GetCOO()->GetAdj(etype, transpose, fmt); } else { LOG(FATAL) << "unsupported adjacency matrix format: " << fmt; return {}; } } HeteroSubgraph UnitGraph::VertexSubgraph( const std::vector& vids) const { // We prefer to generate a subgraph from out-csr. SparseFormat fmt = SelectFormat(CSR_CODE); HeteroSubgraph sg = GetFormat(fmt)->VertexSubgraph(vids); HeteroSubgraph ret; CSRPtr subcsr = nullptr; CSRPtr subcsc = nullptr; COOPtr subcoo = nullptr; switch (fmt) { case SparseFormat::kCSR: subcsr = std::dynamic_pointer_cast(sg.graph); break; case SparseFormat::kCSC: subcsc = std::dynamic_pointer_cast(sg.graph); break; case SparseFormat::kCOO: subcoo = std::dynamic_pointer_cast(sg.graph); break; default: LOG(FATAL) << "[BUG] unsupported format " << static_cast(fmt); return ret; } ret.graph = HeteroGraphPtr(new UnitGraph(meta_graph(), subcsc, subcsr, subcoo)); ret.induced_vertices = std::move(sg.induced_vertices); ret.induced_edges = std::move(sg.induced_edges); return ret; } HeteroSubgraph UnitGraph::EdgeSubgraph( const std::vector& eids, bool preserve_nodes) const { SparseFormat fmt = SelectFormat(COO_CODE); auto sg = GetFormat(fmt)->EdgeSubgraph(eids, preserve_nodes); HeteroSubgraph ret; CSRPtr subcsr = nullptr; CSRPtr subcsc = nullptr; COOPtr subcoo = nullptr; switch (fmt) { case SparseFormat::kCSR: subcsr = std::dynamic_pointer_cast(sg.graph); break; case SparseFormat::kCSC: subcsc = std::dynamic_pointer_cast(sg.graph); break; case SparseFormat::kCOO: subcoo = std::dynamic_pointer_cast(sg.graph); break; default: LOG(FATAL) << "[BUG] unsupported format " << static_cast(fmt); return ret; } ret.graph = HeteroGraphPtr(new UnitGraph(meta_graph(), subcsc, subcsr, subcoo)); ret.induced_vertices = std::move(sg.induced_vertices); ret.induced_edges = std::move(sg.induced_edges); return ret; } HeteroGraphPtr UnitGraph::CreateFromCOO( int64_t num_vtypes, int64_t num_src, int64_t num_dst, IdArray row, IdArray col, bool row_sorted, bool col_sorted, dgl_format_code_t formats) { CHECK(num_vtypes == 1 || num_vtypes == 2); if (num_vtypes == 1) CHECK_EQ(num_src, num_dst); auto mg = CreateUnitGraphMetaGraph(num_vtypes); COOPtr coo(new COO(mg, num_src, num_dst, row, col, row_sorted, col_sorted)); return HeteroGraphPtr(new UnitGraph(mg, nullptr, nullptr, coo, formats)); } HeteroGraphPtr UnitGraph::CreateFromCOO( int64_t num_vtypes, const aten::COOMatrix& mat, dgl_format_code_t formats) { CHECK(num_vtypes == 1 || num_vtypes == 2); if (num_vtypes == 1) CHECK_EQ(mat.num_rows, mat.num_cols); auto mg = CreateUnitGraphMetaGraph(num_vtypes); COOPtr coo(new COO(mg, mat)); return HeteroGraphPtr(new UnitGraph(mg, nullptr, nullptr, coo, formats)); } HeteroGraphPtr UnitGraph::CreateFromCSR( int64_t num_vtypes, int64_t num_src, int64_t num_dst, IdArray indptr, IdArray indices, IdArray edge_ids, dgl_format_code_t formats) { CHECK(num_vtypes == 1 || num_vtypes == 2); if (num_vtypes == 1) CHECK_EQ(num_src, num_dst); auto mg = CreateUnitGraphMetaGraph(num_vtypes); CSRPtr csr(new CSR(mg, num_src, num_dst, indptr, indices, edge_ids)); return HeteroGraphPtr(new UnitGraph(mg, nullptr, csr, nullptr, formats)); } HeteroGraphPtr UnitGraph::CreateFromCSR( int64_t num_vtypes, const aten::CSRMatrix& mat, dgl_format_code_t formats) { CHECK(num_vtypes == 1 || num_vtypes == 2); if (num_vtypes == 1) CHECK_EQ(mat.num_rows, mat.num_cols); auto mg = CreateUnitGraphMetaGraph(num_vtypes); CSRPtr csr(new CSR(mg, mat)); return HeteroGraphPtr(new UnitGraph(mg, nullptr, csr, nullptr, formats)); } HeteroGraphPtr UnitGraph::CreateFromCSRAndCOO( int64_t num_vtypes, const aten::CSRMatrix& csr, const aten::COOMatrix& coo, dgl_format_code_t formats) { CHECK(num_vtypes == 1 || num_vtypes == 2); CHECK_EQ(coo.num_rows, csr.num_rows); CHECK_EQ(coo.num_cols, csr.num_cols); if (num_vtypes == 1) { CHECK_EQ(csr.num_rows, csr.num_cols); } auto mg = CreateUnitGraphMetaGraph(num_vtypes); CSRPtr csrPtr(new CSR(mg, csr)); COOPtr cooPtr(new COO(mg, coo)); return HeteroGraphPtr(new UnitGraph(mg, nullptr, csrPtr, cooPtr, formats)); } HeteroGraphPtr UnitGraph::CreateFromCSC( int64_t num_vtypes, int64_t num_src, int64_t num_dst, IdArray indptr, IdArray indices, IdArray edge_ids, dgl_format_code_t formats) { CHECK(num_vtypes == 1 || num_vtypes == 2); if (num_vtypes == 1) CHECK_EQ(num_src, num_dst); auto mg = CreateUnitGraphMetaGraph(num_vtypes); CSRPtr csc(new CSR(mg, num_dst, num_src, indptr, indices, edge_ids)); return HeteroGraphPtr(new UnitGraph(mg, csc, nullptr, nullptr, formats)); } HeteroGraphPtr UnitGraph::CreateFromCSC( int64_t num_vtypes, const aten::CSRMatrix& mat, dgl_format_code_t formats) { CHECK(num_vtypes == 1 || num_vtypes == 2); if (num_vtypes == 1) CHECK_EQ(mat.num_rows, mat.num_cols); auto mg = CreateUnitGraphMetaGraph(num_vtypes); CSRPtr csc(new CSR(mg, mat)); return HeteroGraphPtr(new UnitGraph(mg, csc, nullptr, nullptr, formats)); } HeteroGraphPtr UnitGraph::CreateFromCSCAndCOO( int64_t num_vtypes, const aten::CSRMatrix& csc, const aten::COOMatrix& coo, dgl_format_code_t formats) { CHECK(num_vtypes == 1 || num_vtypes == 2); CHECK_EQ(coo.num_rows, csc.num_cols); CHECK_EQ(coo.num_cols, csc.num_rows); if (num_vtypes == 1) { CHECK_EQ(csc.num_rows, csc.num_cols); } auto mg = CreateUnitGraphMetaGraph(num_vtypes); CSRPtr cscPtr(new CSR(mg, csc)); COOPtr cooPtr(new COO(mg, coo)); return HeteroGraphPtr(new UnitGraph(mg, cscPtr, nullptr, cooPtr, formats)); } HeteroGraphPtr UnitGraph::AsNumBits(HeteroGraphPtr g, uint8_t bits) { if (g->NumBits() == bits) { return g; } else { auto bg = std::dynamic_pointer_cast(g); CHECK_NOTNULL(bg); CSRPtr new_incsr = (bg->in_csr_->defined()) ? CSRPtr(new CSR(bg->in_csr_->AsNumBits(bits))) : nullptr; CSRPtr new_outcsr = (bg->out_csr_->defined()) ? CSRPtr(new CSR(bg->out_csr_->AsNumBits(bits))) : nullptr; COOPtr new_coo = (bg->coo_->defined()) ? COOPtr(new COO(bg->coo_->AsNumBits(bits))) : nullptr; return HeteroGraphPtr(new UnitGraph( g->meta_graph(), new_incsr, new_outcsr, new_coo, bg->formats_)); } } HeteroGraphPtr UnitGraph::CopyTo(HeteroGraphPtr g, const DGLContext& ctx) { if (ctx == g->Context()) { return g; } else { auto bg = std::dynamic_pointer_cast(g); CHECK_NOTNULL(bg); CSRPtr new_incsr = (bg->in_csr_->defined()) ? CSRPtr(new CSR(bg->in_csr_->CopyTo(ctx))) : nullptr; CSRPtr new_outcsr = (bg->out_csr_->defined()) ? CSRPtr(new CSR(bg->out_csr_->CopyTo(ctx))) : nullptr; COOPtr new_coo = (bg->coo_->defined()) ? COOPtr(new COO(bg->coo_->CopyTo(ctx))) : nullptr; return HeteroGraphPtr(new UnitGraph( g->meta_graph(), new_incsr, new_outcsr, new_coo, bg->formats_)); } } HeteroGraphPtr UnitGraph::PinMemory() { CSRPtr pinned_in_csr, pinned_out_csr; COOPtr pinned_coo; if (this->in_csr_->defined() && this->in_csr_->IsPinned()) { pinned_in_csr = this->in_csr_; } else if (this->in_csr_->defined()) { pinned_in_csr = CSRPtr(new CSR(this->in_csr_->PinMemory())); } else { pinned_in_csr = nullptr; } if (this->out_csr_->defined() && this->out_csr_->IsPinned()) { pinned_out_csr = this->out_csr_; } else if (this->out_csr_->defined()) { pinned_out_csr = CSRPtr(new CSR(this->out_csr_->PinMemory())); } else { pinned_out_csr = nullptr; } if (this->coo_->defined() && this->coo_->IsPinned()) { pinned_coo = this->coo_; } else if (this->coo_->defined()) { pinned_coo = COOPtr(new COO(this->coo_->PinMemory())); } else { pinned_coo = nullptr; } return HeteroGraphPtr(new UnitGraph( meta_graph(), pinned_in_csr, pinned_out_csr, pinned_coo, this->formats_)); } void UnitGraph::PinMemory_() { if (this->in_csr_->defined()) this->in_csr_->PinMemory_(); if (this->out_csr_->defined()) this->out_csr_->PinMemory_(); if (this->coo_->defined()) this->coo_->PinMemory_(); } void UnitGraph::UnpinMemory_() { if (this->in_csr_->defined()) this->in_csr_->UnpinMemory_(); if (this->out_csr_->defined()) this->out_csr_->UnpinMemory_(); if (this->coo_->defined()) this->coo_->UnpinMemory_(); } void UnitGraph::RecordStream(DGLStreamHandle stream) { if (this->in_csr_->defined()) this->in_csr_->RecordStream(stream); if (this->out_csr_->defined()) this->out_csr_->RecordStream(stream); if (this->coo_->defined()) this->coo_->RecordStream(stream); this->recorded_streams.push_back(stream); } void UnitGraph::InvalidateCSR() { this->out_csr_ = CSRPtr(new CSR()); } void UnitGraph::InvalidateCSC() { this->in_csr_ = CSRPtr(new CSR()); } void UnitGraph::InvalidateCOO() { this->coo_ = COOPtr(new COO()); } UnitGraph::UnitGraph( GraphPtr metagraph, CSRPtr in_csr, CSRPtr out_csr, COOPtr coo, dgl_format_code_t formats) : BaseHeteroGraph(metagraph), in_csr_(in_csr), out_csr_(out_csr), coo_(coo) { if (!in_csr_) { in_csr_ = CSRPtr(new CSR()); } if (!out_csr_) { out_csr_ = CSRPtr(new CSR()); } if (!coo_) { coo_ = COOPtr(new COO()); } formats_ = formats; dgl_format_code_t created = GetCreatedFormats(); if ((formats | created) != formats) LOG(FATAL) << "Graph created from formats: " << CodeToStr(created) << ", which is not compatible with available formats: " << CodeToStr(formats); CHECK(GetAny()) << "At least one graph structure should exist."; } HeteroGraphPtr UnitGraph::CreateUnitGraphFrom( int num_vtypes, const aten::CSRMatrix& in_csr, const aten::CSRMatrix& out_csr, const aten::COOMatrix& coo, bool has_in_csr, bool has_out_csr, bool has_coo, dgl_format_code_t formats) { auto mg = CreateUnitGraphMetaGraph(num_vtypes); CSRPtr in_csr_ptr = nullptr; CSRPtr out_csr_ptr = nullptr; COOPtr coo_ptr = nullptr; if (has_in_csr) in_csr_ptr = CSRPtr(new CSR(mg, in_csr)); else in_csr_ptr = CSRPtr(new CSR()); if (has_out_csr) out_csr_ptr = CSRPtr(new CSR(mg, out_csr)); else out_csr_ptr = CSRPtr(new CSR()); if (has_coo) coo_ptr = COOPtr(new COO(mg, coo)); else coo_ptr = COOPtr(new COO()); return HeteroGraphPtr( new UnitGraph(mg, in_csr_ptr, out_csr_ptr, coo_ptr, formats)); } UnitGraph::CSRPtr UnitGraph::GetInCSR(bool inplace) const { if (inplace) if (!(formats_ & CSC_CODE)) LOG(FATAL) << "The graph have restricted sparse format " << CodeToStr(formats_) << ", cannot create CSC matrix."; CSRPtr ret = in_csr_; // Prefers converting from COO since it is parallelized. // TODO(BarclayII): need benchmarking. if (!in_csr_->defined()) { if (coo_->defined()) { const auto& newadj = aten::COOToCSR(aten::COOTranspose(coo_->adj())); if (inplace) *(const_cast(this)->in_csr_) = CSR(meta_graph(), newadj); else ret = std::make_shared(meta_graph(), newadj); } else { CHECK(out_csr_->defined()) << "None of CSR, COO exist"; const auto& newadj = aten::CSRTranspose(out_csr_->adj()); if (inplace) *(const_cast(this)->in_csr_) = CSR(meta_graph(), newadj); else ret = std::make_shared(meta_graph(), newadj); } if (inplace) { if (IsPinned()) in_csr_->PinMemory_(); for (auto stream : recorded_streams) in_csr_->RecordStream(stream); } } return ret; } /** @brief Return out csr. If not exist, transpose the other one.*/ UnitGraph::CSRPtr UnitGraph::GetOutCSR(bool inplace) const { if (inplace) if (!(formats_ & CSR_CODE)) LOG(FATAL) << "The graph have restricted sparse format " << CodeToStr(formats_) << ", cannot create CSR matrix."; CSRPtr ret = out_csr_; // Prefers converting from COO since it is parallelized. // TODO(BarclayII): need benchmarking. if (!out_csr_->defined()) { if (coo_->defined()) { const auto& newadj = aten::COOToCSR(coo_->adj()); if (inplace) *(const_cast(this)->out_csr_) = CSR(meta_graph(), newadj); else ret = std::make_shared(meta_graph(), newadj); } else { CHECK(in_csr_->defined()) << "None of CSR, COO exist"; const auto& newadj = aten::CSRTranspose(in_csr_->adj()); if (inplace) *(const_cast(this)->out_csr_) = CSR(meta_graph(), newadj); else ret = std::make_shared(meta_graph(), newadj); } if (inplace) { if (IsPinned()) out_csr_->PinMemory_(); for (auto stream : recorded_streams) out_csr_->RecordStream(stream); } } return ret; } /** @brief Return coo. If not exist, create from csr.*/ UnitGraph::COOPtr UnitGraph::GetCOO(bool inplace) const { if (inplace) if (!(formats_ & COO_CODE)) LOG(FATAL) << "The graph have restricted sparse format " << CodeToStr(formats_) << ", cannot create COO matrix."; COOPtr ret = coo_; if (!coo_->defined()) { if (in_csr_->defined()) { const auto& newadj = aten::COOTranspose(aten::CSRToCOO(in_csr_->adj(), true)); if (inplace) *(const_cast(this)->coo_) = COO(meta_graph(), newadj); else ret = std::make_shared(meta_graph(), newadj); } else { CHECK(out_csr_->defined()) << "Both CSR are missing."; const auto& newadj = aten::CSRToCOO(out_csr_->adj(), true); if (inplace) *(const_cast(this)->coo_) = COO(meta_graph(), newadj); else ret = std::make_shared(meta_graph(), newadj); } if (inplace) { if (IsPinned()) coo_->PinMemory_(); for (auto stream : recorded_streams) coo_->RecordStream(stream); } } return ret; } aten::CSRMatrix UnitGraph::GetCSCMatrix(dgl_type_t etype) const { return GetInCSR()->adj(); } aten::CSRMatrix UnitGraph::GetCSRMatrix(dgl_type_t etype) const { return GetOutCSR()->adj(); } aten::COOMatrix UnitGraph::GetCOOMatrix(dgl_type_t etype) const { return GetCOO()->adj(); } HeteroGraphPtr UnitGraph::GetAny() const { if (in_csr_->defined()) { return in_csr_; } else if (out_csr_->defined()) { return out_csr_; } else { return coo_; } } dgl_format_code_t UnitGraph::GetCreatedFormats() const { dgl_format_code_t ret = 0; if (in_csr_->defined()) ret |= CSC_CODE; if (out_csr_->defined()) ret |= CSR_CODE; if (coo_->defined()) ret |= COO_CODE; return ret; } dgl_format_code_t UnitGraph::GetAllowedFormats() const { return formats_; } HeteroGraphPtr UnitGraph::GetFormat(SparseFormat format) const { switch (format) { case SparseFormat::kCSR: return GetOutCSR(); case SparseFormat::kCSC: return GetInCSR(); default: return GetCOO(); } } HeteroGraphPtr UnitGraph::GetGraphInFormat(dgl_format_code_t formats) const { // Get the created formats. auto created_formats = GetCreatedFormats(); // Get the intersection of formats and created_formats. auto intersection = formats & created_formats; // If the intersection of formats and created_formats is not empty. // The format(s) in the intersection will be retained. if (intersection != 0) { COOPtr coo_ptr = COO_CODE & intersection ? GetCOO(false) : nullptr; CSRPtr in_csr_ptr = CSC_CODE & intersection ? GetInCSR(false) : nullptr; CSRPtr out_csr_ptr = CSR_CODE & intersection ? GetOutCSR(false) : nullptr; return HeteroGraphPtr( new UnitGraph(meta_graph_, in_csr_ptr, out_csr_ptr, coo_ptr, formats)); } // If the intersection of formats and created_formats is empty. // Create a format in the order of COO -> CSR -> CSC. int64_t num_vtypes = NumVertexTypes(); if (COO_CODE & formats) return CreateFromCOO(num_vtypes, GetCOO(false)->adj(), formats); if (CSR_CODE & formats) return CreateFromCSR(num_vtypes, GetOutCSR(false)->adj(), formats); return CreateFromCSC(num_vtypes, GetInCSR(false)->adj(), formats); } SparseFormat UnitGraph::SelectFormat( dgl_format_code_t preferred_formats) const { dgl_format_code_t common = preferred_formats & formats_; dgl_format_code_t created = GetCreatedFormats(); if (common & created) return DecodeFormat(common & created); // NOTE(zihao): hypersparse is currently disabled since many CUDA operators on // COO have not been implmented yet. if (coo_->defined() && // coo_->IsHypersparse()) // only allow coo for hypersparse graph. // return SparseFormat::kCOO; if (common) return DecodeFormat(common); return DecodeFormat(created); } GraphPtr UnitGraph::AsImmutableGraph() const { CHECK(NumVertexTypes() == 1) << "not a homogeneous graph"; dgl::CSRPtr in_csr_ptr = nullptr, out_csr_ptr = nullptr; dgl::COOPtr coo_ptr = nullptr; if (in_csr_->defined()) { aten::CSRMatrix csc = GetCSCMatrix(0); in_csr_ptr = dgl::CSRPtr(new dgl::CSR(csc.indptr, csc.indices, csc.data)); } if (out_csr_->defined()) { aten::CSRMatrix csr = GetCSRMatrix(0); out_csr_ptr = dgl::CSRPtr(new dgl::CSR(csr.indptr, csr.indices, csr.data)); } if (coo_->defined()) { aten::COOMatrix coo = GetCOOMatrix(0); if (!COOHasData(coo)) { coo_ptr = dgl::COOPtr(new dgl::COO(NumVertices(0), coo.row, coo.col)); } else { IdArray new_src = Scatter(coo.row, coo.data); IdArray new_dst = Scatter(coo.col, coo.data); coo_ptr = dgl::COOPtr(new dgl::COO(NumVertices(0), new_src, new_dst)); } } return GraphPtr(new dgl::ImmutableGraph(in_csr_ptr, out_csr_ptr, coo_ptr)); } HeteroGraphPtr UnitGraph::LineGraph(bool backtracking) const { // TODO(xiangsx) currently we only support homogeneous graph auto fmt = SelectFormat(ALL_CODE); switch (fmt) { case SparseFormat::kCOO: { return CreateFromCOO(1, aten::COOLineGraph(coo_->adj(), backtracking)); } case SparseFormat::kCSR: { const aten::CSRMatrix csr = GetCSRMatrix(0); const aten::COOMatrix coo = aten::COOLineGraph(aten::CSRToCOO(csr, true), backtracking); return CreateFromCOO(1, coo); } case SparseFormat::kCSC: { const aten::CSRMatrix csc = GetCSCMatrix(0); const aten::CSRMatrix csr = aten::CSRTranspose(csc); const aten::COOMatrix coo = aten::COOLineGraph(aten::CSRToCOO(csr, true), backtracking); return CreateFromCOO(1, coo); } default: LOG(FATAL) << "None of CSC, CSR, COO exist"; break; } return nullptr; } constexpr uint64_t kDGLSerialize_UnitGraphMagic = 0xDD2E60F0F6B4A127; bool UnitGraph::Load(dmlc::Stream* fs) { uint64_t magicNum; CHECK(fs->Read(&magicNum)) << "Invalid Magic Number"; CHECK_EQ(magicNum, kDGLSerialize_UnitGraphMagic) << "Invalid UnitGraph Data"; int64_t save_format_code, formats_code; CHECK(fs->Read(&save_format_code)) << "Invalid format"; CHECK(fs->Read(&formats_code)) << "Invalid format"; dgl_format_code_t save_formats = ANY_CODE; if (save_format_code >> 32) { save_formats = static_cast(0xffffffff & save_format_code); } else { save_formats = SparseFormatsToCode({static_cast(save_format_code)}); } if (formats_code >> 32) { formats_ = static_cast(0xffffffff & formats_code); } else { // NOTE(zihao): to be compatible with old formats. switch (formats_code & 0xffffffff) { case 0: formats_ = ALL_CODE; break; case 1: formats_ = COO_CODE; break; case 2: formats_ = CSR_CODE; break; case 3: formats_ = CSC_CODE; break; default: LOG(FATAL) << "Load graph failed, formats code " << formats_code << "not recognized."; } } if (save_formats & COO_CODE) { fs->Read(&coo_); } if (save_formats & CSR_CODE) { fs->Read(&out_csr_); } if (save_formats & CSC_CODE) { fs->Read(&in_csr_); } if (!coo_ && !out_csr_ && !in_csr_) { LOG(FATAL) << "unsupported format code"; } if (!in_csr_) { in_csr_ = CSRPtr(new CSR()); } if (!out_csr_) { out_csr_ = CSRPtr(new CSR()); } if (!coo_) { coo_ = COOPtr(new COO()); } meta_graph_ = GetAny()->meta_graph(); return true; } void UnitGraph::Save(dmlc::Stream* fs) const { fs->Write(kDGLSerialize_UnitGraphMagic); // Didn't write UnitGraph::meta_graph_, since it's included in the underlying // sparse matrix auto save_formats = SparseFormatsToCode({SelectFormat(ALL_CODE)}); auto fstream = dynamic_cast(fs); if (fstream) { auto formats = fstream->FormatsToSave(); save_formats = formats == ANY_CODE ? SparseFormatsToCode({SelectFormat(ALL_CODE)}) : formats; } fs->Write(static_cast(save_formats | 0x100000000)); fs->Write(static_cast(formats_ | 0x100000000)); if (save_formats & COO_CODE) { fs->Write(GetCOO()); } if (save_formats & CSR_CODE) { fs->Write(GetOutCSR()); } if (save_formats & CSC_CODE) { fs->Write(GetInCSR()); } } UnitGraphPtr UnitGraph::Reverse() const { CSRPtr new_incsr = out_csr_, new_outcsr = in_csr_; COOPtr new_coo = nullptr; if (coo_->defined()) { new_coo = COOPtr(new COO(coo_->meta_graph(), aten::COOTranspose(coo_->adj()))); } return UnitGraphPtr( new UnitGraph(meta_graph(), new_incsr, new_outcsr, new_coo)); } std::tuple UnitGraph::ToSimple() const { CSRPtr new_incsr = nullptr, new_outcsr = nullptr; COOPtr new_coo = nullptr; IdArray count; IdArray edge_map; auto avail_fmt = SelectFormat(ALL_CODE); switch (avail_fmt) { case SparseFormat::kCOO: { auto ret = aten::COOToSimple(GetCOO()->adj()); count = std::get<1>(ret); edge_map = std::get<2>(ret); new_coo = COOPtr(new COO(meta_graph(), std::get<0>(ret))); break; } case SparseFormat::kCSR: { auto ret = aten::CSRToSimple(GetOutCSR()->adj()); count = std::get<1>(ret); edge_map = std::get<2>(ret); new_outcsr = CSRPtr(new CSR(meta_graph(), std::get<0>(ret))); break; } case SparseFormat::kCSC: { auto ret = aten::CSRToSimple(GetInCSR()->adj()); count = std::get<1>(ret); edge_map = std::get<2>(ret); new_incsr = CSRPtr(new CSR(meta_graph(), std::get<0>(ret))); break; } default: LOG(FATAL) << "At lease one of COO, CSR or CSC adj should exist."; break; } return std::make_tuple( UnitGraphPtr(new UnitGraph(meta_graph(), new_incsr, new_outcsr, new_coo)), count, edge_map); } } // namespace dgl