/*! * Copyright (c) 2019 by Contributors * \file graph/serialize/heterograph_data.h * \brief Graph serialization header */ #ifndef DGL_GRAPH_SERIALIZE_HETEROGRAPH_DATA_H_ #define DGL_GRAPH_SERIALIZE_HETEROGRAPH_DATA_H_ #include #include #include #include #include #include #include #include #include #include #include #include #include #include "../../c_api_common.h" #include "../heterograph.h" using dgl::runtime::NDArray; using namespace dgl::runtime; namespace dgl { namespace serialize { typedef std::pair NamedTensor; class HeteroGraphDataObject : public runtime::Object { public: std::shared_ptr gptr; std::vector> node_tensors; std::vector> edge_tensors; std::vector etype_names; std::vector ntype_names; static constexpr const char *_type_key = "heterograph_serialize.HeteroGraphData"; HeteroGraphDataObject() {} HeteroGraphDataObject(HeteroGraphPtr gptr, List> ndata, List> edata, List ntype_names, List etype_names) { this->gptr = std::dynamic_pointer_cast(gptr); CHECK_NOTNULL(this->gptr); for (auto nd_dict : ndata) { node_tensors.emplace_back(); for (auto kv : nd_dict) { auto last = &node_tensors.back(); NDArray ndarray = kv.second->data; last->emplace_back(kv.first, ndarray); } } for (auto nd_dict : edata) { edge_tensors.emplace_back(); for (auto kv : nd_dict) { auto last = &edge_tensors.back(); NDArray ndarray = kv.second->data; last->emplace_back(kv.first, ndarray); } } this->ntype_names = ListValueToVector(ntype_names); this->etype_names = ListValueToVector(etype_names); } void Save(dmlc::Stream *fs) const { fs->Write(gptr); fs->Write(node_tensors); fs->Write(edge_tensors); fs->Write(ntype_names); fs->Write(etype_names); } bool Load(dmlc::Stream *fs) { fs->Read(&gptr); fs->Read(&node_tensors); fs->Read(&edge_tensors); fs->Read(&ntype_names); fs->Read(&etype_names); return true; } DGL_DECLARE_OBJECT_TYPE_INFO(HeteroGraphDataObject, runtime::Object); }; class HeteroGraphData : public runtime::ObjectRef { public: DGL_DEFINE_OBJECT_REF_METHODS(HeteroGraphData, runtime::ObjectRef, HeteroGraphDataObject); /*! \brief create a new GraphData reference */ static HeteroGraphData Create(HeteroGraphPtr gptr, List> node_tensors, List> edge_tensors, List ntype_names, List etype_names) { return HeteroGraphData(std::make_shared( gptr, node_tensors, edge_tensors, ntype_names, etype_names)); } /*! \brief create an empty GraphData reference */ static HeteroGraphData Create() { return HeteroGraphData(std::make_shared()); } }; } // namespace serialize } // namespace dgl namespace dmlc { DMLC_DECLARE_TRAITS(has_saveload, dgl::serialize::HeteroGraphDataObject, true); } #endif // DGL_GRAPH_SERIALIZE_HETEROGRAPH_DATA_H_