Unverified Commit 1506560e authored by Jinjing Zhou's avatar Jinjing Zhou Committed by GitHub
Browse files

[Data] Add utils to save dict of tensors (#1481)

* add functions

* fix lint

* add unit test

* fix

* fix
parent 100ddd06
"""For Tensor Serialization"""
from __future__ import absolute_import
from .._ffi.function import _init_api
from .. import backend as F
# Public API of this module.
__all__ = ['save_tensors', "load_tensors"]
# NOTE(review): presumably this injects the C API functions registered under
# "data.tensor_serialize" (e.g. _CAPI_SaveNDArrayDict / _CAPI_LoadNDArrayDict)
# into this module's namespace -- confirm against _ffi.function._init_api.
_init_api("dgl.data.tensor_serialize")
def save_tensors(filename, tensor_dict):
    """
    Save dict of tensors to file

    Parameters
    ----------
    filename : str
        File name to store dict of tensors.
    tensor_dict: dict of dgl NDArray or backend tensor
        Python dict using string as key and tensor as value

    Returns
    -------
    The value returned by the underlying ``_CAPI_SaveNDArrayDict`` call.

    Raises
    ------
    Exception
        If a key is not a str, or a value is neither a backend tensor
        nor a dgl NDArray.
    """
    # ``nd`` is not imported at module level; import it here so that the
    # NDArray branch below works instead of raising NameError.
    from .. import ndarray as nd

    nd_dict = {}
    for key, value in tensor_dict.items():
        if not isinstance(key, str):
            raise Exception("Dict key has to be str")
        if F.is_tensor(value):
            # Backend tensors are converted to DGL NDArrays before being
            # handed to the C API.
            nd_dict[key] = F.zerocopy_to_dgl_ndarray(value)
        elif isinstance(value, nd.NDArray):
            nd_dict[key] = value
        else:
            raise Exception(
                "Dict value has to be backend tensor or dgl ndarray")
    return _CAPI_SaveNDArrayDict(filename, nd_dict)
def load_tensors(filename, return_dgl_ndarray=False):
    """
    Load a dict of tensors from file

    Parameters
    ----------
    filename : str
        File name to load dict of tensors.
    return_dgl_ndarray: bool
        If True, the values of the returned dict are dgl NDArrays;
        otherwise they are converted to backend tensors.

    Returns
    -------
    dict
        Python dict mapping each stored key to its tensor.
    """
    loaded = _CAPI_LoadNDArrayDict(filename)
    if return_dgl_ndarray:
        # Each entry wraps its NDArray in a ``data`` attribute; unwrap it.
        return {name: boxed.data for name, boxed in loaded.items()}
    return {name: F.zerocopy_from_dgl_ndarray(boxed.data)
            for name, boxed in loaded.items()}
...@@ -12,10 +12,11 @@ import warnings ...@@ -12,10 +12,11 @@ import warnings
import requests import requests
from .graph_serialize import save_graphs, load_graphs, load_labels from .graph_serialize import save_graphs, load_graphs, load_labels
from .tensor_serialize import save_tensors, load_tensors
__all__ = ['loadtxt','download', 'check_sha1', 'extract_archive', __all__ = ['loadtxt','download', 'check_sha1', 'extract_archive',
'get_download_dir', 'Subset', 'split_dataset', 'get_download_dir', 'Subset', 'split_dataset',
'save_graphs', "load_graphs", "load_labels"] 'save_graphs', "load_graphs", "load_labels", "save_tensors", "load_tensors"]
def loadtxt(path, delimiter, dtype=None): def loadtxt(path, delimiter, dtype=None):
try: try:
......
/*! /*!
* Copyright (c) 2019 by Contributors * Copyright (c) 2019 by Contributors
* \file graph/graph_serialize.cc * \file graph/serialize/graph_serialize.cc
* \brief Graph serialization implementation * \brief Graph serialization implementation
* *
* The storage structure is * The storage structure is
......
/*! /*!
* Copyright (c) 2019 by Contributors * Copyright (c) 2019 by Contributors
* \file graph/graph_serialize.h * \file graph/serialize/graph_serialize.h
* \brief Graph serialization header * \brief Graph serialization header
*/ */
#ifndef DGL_GRAPH_GRAPH_SERIALIZE_H_ #ifndef DGL_GRAPH_SERIALIZE_GRAPH_SERIALIZE_H_
#define DGL_GRAPH_GRAPH_SERIALIZE_H_ #define DGL_GRAPH_SERIALIZE_GRAPH_SERIALIZE_H_
#include <dgl/graph.h> #include <dgl/graph.h>
#include <dgl/array.h> #include <dgl/array.h>
...@@ -20,7 +20,7 @@ ...@@ -20,7 +20,7 @@
#include <vector> #include <vector>
#include <algorithm> #include <algorithm>
#include <utility> #include <utility>
#include "../c_api_common.h" #include "../../c_api_common.h"
using dgl::runtime::NDArray; using dgl::runtime::NDArray;
using dgl::ImmutableGraph; using dgl::ImmutableGraph;
...@@ -112,4 +112,4 @@ ImmutableGraphPtr ToImmutableGraph(GraphPtr g); ...@@ -112,4 +112,4 @@ ImmutableGraphPtr ToImmutableGraph(GraphPtr g);
} // namespace serialize } // namespace serialize
} // namespace dgl } // namespace dgl
#endif // DGL_GRAPH_GRAPH_SERIALIZE_H_ #endif // DGL_GRAPH_SERIALIZE_GRAPH_SERIALIZE_H_
/*!
* Copyright (c) 2019 by Contributors
* \file graph/serialize/tensor_serialize.cc
 * \brief Tensor serialization implementation
*/
#include <dgl/packed_func_ext.h>
#include <dgl/runtime/container.h>
#include <dgl/runtime/ndarray.h>
#include <dgl/runtime/object.h>
#include <dmlc/io.h>
#include "../../c_api_common.h"
using namespace dgl::runtime;
using dmlc::SeekStream;
namespace dgl {
namespace serialize {
typedef std::pair<std::string, NDArray> NamedTensor;
// C API: serialize a (name -> NDArray) map to a file.
//   args[0]: output file name.
//   args[1]: Map<std::string, Value>; each Value's data is cast to an NDArray.
// Sets *rv to true once the data has been written.
DGL_REGISTER_GLOBAL("data.tensor_serialize._CAPI_SaveNDArrayDict")
.set_body([](DGLArgs args, DGLRetValue *rv) {
    std::string filename = args[0];
    Map<std::string, Value> nd_dict = args[1];
    // Flatten the map into a vector of (name, NDArray) pairs, which is the
    // form handed to the stream's Write below.
    std::vector<NamedTensor> namedTensors;
    for (auto kv : nd_dict) {
      NDArray ndarray = static_cast<NDArray>(kv.second->data);
      namedTensors.emplace_back(kv.first, ndarray);
    }
    auto *fs = dynamic_cast<SeekStream *>(
      SeekStream::Create(filename.c_str(), "w", true));
    // The stream is created with allow_null = true and then dynamic_cast, so
    // the pointer may be null; fail with a message rather than dereferencing.
    CHECK(fs) << "Failed to open file for writing: " << filename;
    fs->Write(namedTensors);
    delete fs;
    *rv = true;
});
// C API: load a (name -> NDArray) map previously written by
// _CAPI_SaveNDArrayDict.
//   args[0]: input file name.
// Sets *rv to a Map<std::string, Value> whose values wrap the loaded NDArrays.
DGL_REGISTER_GLOBAL("data.tensor_serialize._CAPI_LoadNDArrayDict")
.set_body([](DGLArgs args, DGLRetValue *rv) {
    std::string filename = args[0];
    Map<std::string, Value> nd_dict;
    std::vector<NamedTensor> namedTensors;
    SeekStream *fs = SeekStream::CreateForRead(filename.c_str(), true);
    CHECK(fs) << "Filename is invalid or file doesn't exists";
    // Release the stream before checking the result so that a failed read
    // does not leak it when CHECK aborts the call.
    const bool read_ok = fs->Read(&namedTensors);
    delete fs;
    CHECK(read_ok) << "Failed to read tensor dict from file: " << filename;
    for (const auto &kv : namedTensors) {
      // Box each NDArray in a Value so it can be stored in the Map.
      Value ndarray = Value(MakeValue(kv.second));
      nd_dict.Set(kv.first, ndarray);
    }
    *rv = nd_dict;
});
} // namespace serialize
} // namespace dgl
...@@ -7,7 +7,8 @@ import os ...@@ -7,7 +7,8 @@ import os
from dgl import DGLGraph from dgl import DGLGraph
import dgl import dgl
from dgl.data.utils import save_graphs, load_graphs, load_labels import dgl.ndarray as nd
from dgl.data.utils import save_graphs, load_graphs, load_labels, save_tensors, load_tensors
np.random.seed(44) np.random.seed(44)
...@@ -133,7 +134,37 @@ def test_graph_serialize_with_labels(): ...@@ -133,7 +134,37 @@ def test_graph_serialize_with_labels():
os.unlink(path) os.unlink(path)
def test_serialize_tensors():
    """Round-trip a dict of backend tensors through save_tensors/load_tensors.

    Checks both return modes of load_tensors: backend tensors (default) and
    dgl NDArrays (``return_dgl_ndarray=True``).
    """
    # create a temporary file and immediately release it so DGL can open it.
    f = tempfile.NamedTemporaryFile(delete=False)
    path = f.name
    f.close()

    # try/finally so the temporary file is removed even when an assertion
    # (or the save/load call itself) fails.
    try:
        tensor_dict = {"a": F.tensor(
            [1, 3, -1, 0], dtype=F.int64), "1@1": F.tensor([1.5, 2], dtype=F.float32)}

        save_tensors(path, tensor_dict)

        load_tensor_dict = load_tensors(path)

        for key in tensor_dict:
            assert key in load_tensor_dict
            assert np.array_equal(
                F.asnumpy(load_tensor_dict[key]), F.asnumpy(tensor_dict[key]))

        load_nd_dict = load_tensors(path, return_dgl_ndarray=True)

        for key in tensor_dict:
            assert key in load_nd_dict
            assert isinstance(load_nd_dict[key], nd.NDArray)
            assert np.array_equal(
                load_nd_dict[key].asnumpy(), F.asnumpy(tensor_dict[key]))
    finally:
        os.unlink(path)
if __name__ == "__main__": if __name__ == "__main__":
test_graph_serialize_with_feature() test_graph_serialize_with_feature()
test_graph_serialize_without_feature() test_graph_serialize_without_feature()
test_graph_serialize_with_labels() test_graph_serialize_with_labels()
test_serialize_tensors()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment