Unverified Commit 30b8074a authored by Jinjing Zhou's avatar Jinjing Zhou Committed by GitHub
Browse files

[Fix] Fix stream creation and empty dict in tensor serialization (#1489)

* add functions

* fix lint

* add unit test

* fix

* fix

* fix

* fix

* support empty dict

* simplify logic

* Update tensor_serialize.py
parent f25bc176
"""For Tensor Serialization""" """For Tensor Serialization"""
from __future__ import absolute_import from __future__ import absolute_import
from ..ndarray import NDArray
from .._ffi.function import _init_api from .._ffi.function import _init_api
from .. import backend as F from .. import backend as F
...@@ -18,20 +19,26 @@ def save_tensors(filename, tensor_dict): ...@@ -18,20 +19,26 @@ def save_tensors(filename, tensor_dict):
File name to store dict of tensors. File name to store dict of tensors.
tensor_dict: dict of dgl NDArray or backend tensor tensor_dict: dict of dgl NDArray or backend tensor
Python dict using string as key and tensor as value Python dict using string as key and tensor as value
Returns
----------
status : bool
Return whether save operation succeeds
""" """
nd_dict = {} nd_dict = {}
is_empty_dict = len(tensor_dict) == 0
for key, value in tensor_dict.items(): for key, value in tensor_dict.items():
if not isinstance(key, str): if not isinstance(key, str):
raise Exception("Dict key has to be str") raise Exception("Dict key has to be str")
if F.is_tensor(value): if F.is_tensor(value):
nd_dict[key] = F.zerocopy_to_dgl_ndarray(value) nd_dict[key] = F.zerocopy_to_dgl_ndarray(value)
elif isinstance(value, nd.NDArray): elif isinstance(value, NDArray):
nd_dict[key] = value nd_dict[key] = value
else: else:
raise Exception( raise Exception(
"Dict value has to be backend tensor or dgl ndarray") "Dict value has to be backend tensor or dgl ndarray")
return _CAPI_SaveNDArrayDict(filename, nd_dict) return _CAPI_SaveNDArrayDict(filename, nd_dict, is_empty_dict)
def load_tensors(filename, return_dgl_ndarray=False): def load_tensors(filename, return_dgl_ndarray=False):
...@@ -44,6 +51,11 @@ def load_tensors(filename, return_dgl_ndarray=False): ...@@ -44,6 +51,11 @@ def load_tensors(filename, return_dgl_ndarray=False):
File name to load dict of tensors. File name to load dict of tensors.
return_dgl_ndarray: bool return_dgl_ndarray: bool
Whether return dict of dgl NDArrays or backend tensors Whether return dict of dgl NDArrays or backend tensors
Returns
---------
tensor_dict : dict
dict of tensor or ndarray based on return_dgl_ndarray flag
""" """
nd_dict = _CAPI_LoadNDArrayDict(filename) nd_dict = _CAPI_LoadNDArrayDict(filename)
tensor_dict = {} tensor_dict = {}
......
...@@ -19,36 +19,48 @@ namespace serialize { ...@@ -19,36 +19,48 @@ namespace serialize {
typedef std::pair<std::string, NDArray> NamedTensor; typedef std::pair<std::string, NDArray> NamedTensor;
constexpr uint64_t kDGLSerialize_Tensors = 0xDD5A9FBE3FA2443F;
// FFI entry point that writes a dict of named NDArrays to a file.
//
// Arguments, as read from `args`:
//   args[0]  filename to write to
//   args[1]  Map<std::string, Value> from name to NDArray
//            (only converted when args[2] is false)
//   args[2]  bool, true when the caller's dict is empty
//
// Data is written in this order: the kDGLSerialize_Tensors magic number
// (uint64), the number of entries (uint64), then the vector of
// (name, NDArray) pairs. The return value is set to true once all writes
// have been issued.
DGL_REGISTER_GLOBAL("data.tensor_serialize._CAPI_SaveNDArrayDict")
.set_body([](DGLArgs args, DGLRetValue *rv) {
    // Deletes the stream on every exit path, so it is not leaked if a
    // statement below exits early (for example if a failed CHECK throws,
    // which depends on how logging is configured for this build).
    struct StreamGuard {
      dmlc::Stream *stream;
      ~StreamGuard() { delete stream; }
    };

    // Unpack all arguments before touching the file, so a bad argument does
    // not leave behind a freshly opened, half-written file.
    const std::string filename = args[0];
    const bool empty_dict = args[2];
    Map<std::string, Value> nd_dict;
    // NOTE(review): args[1] is skipped for an empty dict, presumably because
    // an empty Python dict does not convert to a Map -- confirm against the
    // FFI layer.
    if (!empty_dict) {
      nd_dict = args[1];
    }

    std::vector<NamedTensor> namedTensors;
    namedTensors.reserve(nd_dict.size());
    for (const auto &kv : nd_dict) {
      namedTensors.emplace_back(kv.first,
                                static_cast<NDArray>(kv.second->data));
    }

    StreamGuard guard{dmlc::Stream::Create(filename.c_str(), "w")};
    dmlc::Stream *fs = guard.stream;
    CHECK(fs) << "Filename is invalid";
    fs->Write(kDGLSerialize_Tensors);
    fs->Write(static_cast<uint64_t>(nd_dict.size()));
    fs->Write(namedTensors);
    *rv = true;
});
// FFI entry point that reads a dict of named NDArrays from a file.
//
// Arguments, as read from `args`:
//   args[0]  filename to read from
//
// Data is read in this order: a uint64 magic number that must equal
// kDGLSerialize_Tensors, a uint64 entry count, then the vector of
// (name, NDArray) pairs. The return value is set to a
// Map<std::string, Value> holding one entry per pair that was read.
DGL_REGISTER_GLOBAL("data.tensor_serialize._CAPI_LoadNDArrayDict")
.set_body([](DGLArgs args, DGLRetValue *rv) {
    // Deletes the stream on every exit path, so the stream is not leaked when
    // one of the validation checks below fails (for example if a failed CHECK
    // throws, which depends on how logging is configured for this build).
    struct StreamGuard {
      dmlc::Stream *stream;
      ~StreamGuard() { delete stream; }
    };

    const std::string filename = args[0];
    StreamGuard guard{dmlc::Stream::Create(filename.c_str(), "r")};
    dmlc::Stream *fs = guard.stream;
    CHECK(fs) << "Filename is invalid or file doesn't exists";

    uint64_t magic_num = 0;
    uint64_t num_elements = 0;
    CHECK(fs->Read(&magic_num)) << "Invalid file";
    CHECK_EQ(magic_num, kDGLSerialize_Tensors) << "Invalid DGL tensor file";
    CHECK(fs->Read(&num_elements)) << "Invalid num of elements";

    std::vector<NamedTensor> namedTensors;
    // A failed or truncated read must not be returned as a valid dict.
    CHECK(fs->Read(&namedTensors)) << "Invalid tensor data";
    // The count stored in the header has to agree with what was decoded.
    CHECK_EQ(static_cast<uint64_t>(namedTensors.size()), num_elements)
        << "Number of tensors does not match the file header";

    Map<std::string, Value> nd_dict;
    for (const auto &kv : namedTensors) {
      nd_dict.Set(kv.first, Value(MakeValue(kv.second)));
    }
    *rv = nd_dict;
});
} // namespace serialize } // namespace serialize
......
...@@ -162,9 +162,25 @@ def test_serialize_tensors(): ...@@ -162,9 +162,25 @@ def test_serialize_tensors():
os.unlink(path) os.unlink(path)
def test_serialize_empty_dict():
    """Check that an empty dict survives a save_tensors/load_tensors round trip."""
    # create a temporary file and immediately release it so DGL can open it.
    f = tempfile.NamedTemporaryFile(delete=False)
    path = f.name
    f.close()

    try:
        tensor_dict = {}
        save_tensors(path, tensor_dict)

        load_tensor_dict = load_tensors(path)
        assert isinstance(load_tensor_dict, dict)
        assert len(load_tensor_dict) == 0
    finally:
        # delete=False means nothing else removes the file; clean up even
        # when saving, loading or an assertion fails.
        os.unlink(path)
if __name__ == "__main__":
    # Run every serialization test, in order, when executed as a script.
    for run_test in (
            test_graph_serialize_with_feature,
            test_graph_serialize_without_feature,
            test_graph_serialize_with_labels,
            test_serialize_tensors,
            test_serialize_empty_dict,
    ):
        run_test()
\ No newline at end of file
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment