/*! * Copyright (c) 2019 by Contributors * \file graph/serialize/graph_serialize.cc * \brief Graph serialization implementation * * The storage structure is * { * // MetaData Section * uint64_t kDGLSerializeMagic * uint64_t kVersion * uint64_t GraphType * ** Reserved Area till 4kB ** * * dgl_id_t num_graphs * vector graph_indices (start address of each graph) * vector nodes_num_list (list of number of nodes for each graph) * vector edges_num_list (list of number of edges for each graph) * * vector graph_datas; * * } * * Storage of GraphData is * { * // Everything uses in csr * NDArray indptr * NDArray indices * NDArray edge_ids * vector> node_tensors; * vector> edge_tensors; * } * */ #include #include #include #include #include #include #include #include #include #include #include #include #include "graph_serialize.h" using namespace dgl::runtime; using dgl::COO; using dgl::COOPtr; using dgl::ImmutableGraph; using dgl::runtime::NDArray; using dgl::serialize::GraphData; using dgl::serialize::GraphDataObject; using dmlc::SeekStream; using std::vector; namespace dmlc { DMLC_DECLARE_TRAITS(has_saveload, GraphDataObject, true); } namespace dgl { namespace serialize { bool SaveDGLGraphs(std::string filename, List graph_data, std::vector labels_list) { auto fs = std::unique_ptr(dynamic_cast( SeekStream::Create(filename.c_str(), "w", true))); CHECK(fs) << "File name " << filename << " is not a valid local file name"; // Write DGL MetaData const uint64_t kVersion = 1; fs->Write(kDGLSerializeMagic); fs->Write(kVersion); fs->Write(GraphType::kImmutableGraph); fs->Seek(4096); // Write Graph Meta Data dgl_id_t num_graph = graph_data.size(); std::vector graph_indices(num_graph); std::vector nodes_num_list(num_graph); std::vector edges_num_list(num_graph); for (uint64_t i = 0; i < num_graph; ++i) { nodes_num_list[i] = graph_data[i]->gptr->NumVertices(); edges_num_list[i] = graph_data[i]->gptr->NumEdges(); } // Reserve spaces for graph indices fs->Write(num_graph); dgl_id_t indices_start_ptr = fs->Tell(); fs->Write(graph_indices); fs->Write(nodes_num_list); fs->Write(edges_num_list); fs->Write(labels_list); // Write GraphData for (uint64_t i = 0; i < num_graph; ++i) { graph_indices[i] = fs->Tell(); GraphDataObject gdata = *graph_data[i].as(); fs->Write(gdata); } fs->Seek(indices_start_ptr); fs->Write(graph_indices); return true; } StorageMetaData LoadDGLGraphs(const std::string &filename, std::vector idx_list, bool onlyMeta) { auto fs = std::unique_ptr( SeekStream::CreateForRead(filename.c_str(), true)); CHECK(fs) << "Filename is invalid"; // Read DGL MetaData uint64_t magicNum, graphType, version; fs->Read(&magicNum); fs->Read(&version); fs->Read(&graphType); fs->Seek(4096); CHECK_EQ(magicNum, kDGLSerializeMagic) << "Invalid DGL files"; CHECK_EQ(version, 1) << "Invalid DGL files"; StorageMetaData metadata = StorageMetaData::Create(); // Read Graph MetaData dgl_id_t num_graph; CHECK(fs->Read(&num_graph)) << "Invalid num of graph"; std::vector graph_indices; std::vector nodes_num_list; std::vector edges_num_list; std::vector labels_list; CHECK(fs->Read(&graph_indices)) << "Invalid graph indices"; CHECK(fs->Read(&nodes_num_list)) << "Invalid node num list"; CHECK(fs->Read(&edges_num_list)) << "Invalid edge num list"; CHECK(fs->Read(&labels_list)) << "Invalid label list"; metadata->SetMetaData(num_graph, nodes_num_list, edges_num_list, labels_list); std::vector gdata_refs; // Early Return if (onlyMeta) { return metadata; } if (idx_list.empty()) { // Read All Graphs gdata_refs.reserve(num_graph); for (uint64_t i = 0; i < num_graph; ++i) { GraphData gdata = GraphData::Create(); GraphDataObject *gdata_ptr = const_cast(gdata.as()); fs->Read(gdata_ptr); gdata_refs.push_back(gdata); } } else { // Read Selected Graphss gdata_refs.reserve(idx_list.size()); // Would be better if idx_list is sorted. However the returned the graphs // should be the same order as the idx_list for (uint64_t i = 0; i < idx_list.size(); ++i) { auto gid = idx_list[i]; CHECK((gid < graph_indices.size()) && (gid >= 0)) << "ID " << gid << " in idx_list is out of bound. Please check your idx_list."; fs->Seek(graph_indices[gid]); GraphData gdata = GraphData::Create(); GraphDataObject *gdata_ptr = const_cast(gdata.as()); fs->Read(gdata_ptr); gdata_refs.push_back(gdata); } } metadata->SetGraphData(gdata_refs); return metadata; } void GraphDataObject::SetData(ImmutableGraphPtr gptr, Map node_tensors, Map edge_tensors) { this->gptr = gptr; for (auto kv : node_tensors) { std::string name = kv.first; Value v = kv.second; NDArray ndarray = static_cast(v->data); this->node_tensors.emplace_back(name, ndarray); } for (auto kv : edge_tensors) { std::string &name = kv.first; Value v = kv.second; const NDArray &ndarray = static_cast(v->data); this->edge_tensors.emplace_back(name, ndarray); } } void GraphDataObject::Save(dmlc::Stream *fs) const { // Using in csr for storage const CSRPtr g_csr = this->gptr->GetInCSR(); fs->Write(g_csr->indptr()); fs->Write(g_csr->indices()); fs->Write(g_csr->edge_ids()); fs->Write(node_tensors); fs->Write(edge_tensors); } bool GraphDataObject::Load(dmlc::Stream *fs) { NDArray indptr, indices, edge_ids; fs->Read(&indptr); fs->Read(&indices); fs->Read(&edge_ids); this->gptr = ImmutableGraph::CreateFromCSR(indptr, indices, edge_ids, "in"); fs->Read(&this->node_tensors); fs->Read(&this->edge_tensors); return true; } ImmutableGraphPtr BatchLoadedGraphs(std::vector gdata_list) { std::vector gptrs; gptrs.reserve(gdata_list.size()); for (auto gdata : gdata_list) { gptrs.push_back(static_cast(gdata->gptr)); } ImmutableGraphPtr imGPtr = std::dynamic_pointer_cast(GraphOp::DisjointUnion(gptrs)); return imGPtr; } ImmutableGraphPtr ToImmutableGraph(GraphPtr g) { ImmutableGraphPtr imgr = std::dynamic_pointer_cast(g); if (imgr) { return imgr; } else { MutableGraphPtr mgr = std::dynamic_pointer_cast(g); CHECK(mgr) << "Invalid Graph Pointer"; EdgeArray earray = mgr->Edges("eid"); IdArray srcs_array = earray.src; IdArray dsts_array = earray.dst; bool row_sorted, col_sorted; std::tie(row_sorted, col_sorted) = COOIsSorted( aten::COOMatrix(mgr->NumVertices(), mgr->NumVertices(), srcs_array, dsts_array)); ImmutableGraphPtr imgptr = ImmutableGraph::CreateFromCOO(mgr->NumVertices(), srcs_array, dsts_array, row_sorted, col_sorted); return imgptr; } } void StorageMetaDataObject::SetMetaData(dgl_id_t num_graph, std::vector nodes_num_list, std::vector edges_num_list, std::vector labels_list) { this->num_graph = num_graph; this->nodes_num_list = Value(MakeValue(aten::VecToIdArray(nodes_num_list))); this->edges_num_list = Value(MakeValue(aten::VecToIdArray(edges_num_list))); for (auto kv : labels_list) { this->labels_list.Set(kv.first, Value(MakeValue(kv.second))); } } void StorageMetaDataObject::SetGraphData(std::vector gdata) { this->graph_data = List(gdata); } } // namespace serialize } // namespace dgl