/*! * Copyright (c) 2019 by Contributors * \file graph/serialize/tensor_serialize.cc * \brief Graph serialization implementation */ #include #include #include #include #include #include "../../c_api_common.h" using namespace dgl::runtime; using dmlc::SeekStream; namespace dgl { namespace serialize { typedef std::pair NamedTensor; DGL_REGISTER_GLOBAL("data.tensor_serialize._CAPI_SaveNDArrayDict") .set_body([](DGLArgs args, DGLRetValue *rv) { std::string filename = args[0]; Map nd_dict = args[1]; std::vector namedTensors; for (auto kv : nd_dict) { NDArray ndarray = static_cast(kv.second->data); namedTensors.emplace_back(kv.first, ndarray); } auto *fs = dynamic_cast( SeekStream::Create(filename.c_str(), "w", true)); fs->Write(namedTensors); delete fs; *rv = true; }); DGL_REGISTER_GLOBAL("data.tensor_serialize._CAPI_LoadNDArrayDict") .set_body([](DGLArgs args, DGLRetValue *rv) { std::string filename = args[0]; Map nd_dict; std::vector namedTensors; SeekStream *fs = SeekStream::CreateForRead(filename.c_str(), true); CHECK(fs) << "Filename is invalid or file doesn't exists"; fs->Read(&namedTensors); for (auto kv : namedTensors) { Value ndarray = Value(MakeValue(kv.second)); nd_dict.Set(kv.first, ndarray); } delete fs; *rv = nd_dict; }); } // namespace serialize } // namespace dgl