Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
1506560e
Unverified
Commit
1506560e
authored
Apr 28, 2020
by
Jinjing Zhou
Committed by
GitHub
Apr 27, 2020
Browse files
[Data] Add utils to save dict of tensors (#1481)
* add functions * fix litn * add unit test * fix * fix
parent
100ddd06
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
150 additions
and
8 deletions
+150
-8
python/dgl/data/tensor_serialize.py
python/dgl/data/tensor_serialize.py
+55
-0
python/dgl/data/utils.py
python/dgl/data/utils.py
+2
-1
src/graph/serialize/graph_serialize.cc
src/graph/serialize/graph_serialize.cc
+1
-1
src/graph/serialize/graph_serialize.h
src/graph/serialize/graph_serialize.h
+5
-5
src/graph/serialize/tensor_serialize.cc
src/graph/serialize/tensor_serialize.cc
+55
-0
tests/compute/test_serialize.py
tests/compute/test_serialize.py
+32
-1
No files found.
python/dgl/data/tensor_serialize.py
0 → 100644
View file @
1506560e
"""For Tensor Serialization"""
from
__future__
import
absolute_import
from
.._ffi.function
import
_init_api
from
..
import
backend
as
F
__all__
=
[
'save_tensors'
,
"load_tensors"
]
_init_api
(
"dgl.data.tensor_serialize"
)
def
save_tensors
(
filename
,
tensor_dict
):
"""
Save dict of tensors to file
Parameters
----------
filename : str
File name to store dict of tensors.
tensor_dict: dict of dgl NDArray or backend tensor
Python dict using string as key and tensor as value
"""
nd_dict
=
{}
for
key
,
value
in
tensor_dict
.
items
():
if
not
isinstance
(
key
,
str
):
raise
Exception
(
"Dict key has to be str"
)
if
F
.
is_tensor
(
value
):
nd_dict
[
key
]
=
F
.
zerocopy_to_dgl_ndarray
(
value
)
elif
isinstance
(
value
,
nd
.
NDArray
):
nd_dict
[
key
]
=
value
else
:
raise
Exception
(
"Dict value has to be backend tensor or dgl ndarray"
)
return
_CAPI_SaveNDArrayDict
(
filename
,
nd_dict
)
def
load_tensors
(
filename
,
return_dgl_ndarray
=
False
):
"""
load dict of tensors from file
Parameters
----------
filename : str
File name to load dict of tensors.
return_dgl_ndarray: bool
Whether return dict of dgl NDArrays or backend tensors
"""
nd_dict
=
_CAPI_LoadNDArrayDict
(
filename
)
tensor_dict
=
{}
for
key
,
value
in
nd_dict
.
items
():
if
return_dgl_ndarray
:
tensor_dict
[
key
]
=
value
.
data
else
:
tensor_dict
[
key
]
=
F
.
zerocopy_from_dgl_ndarray
(
value
.
data
)
return
tensor_dict
python/dgl/data/utils.py
View file @
1506560e
...
@@ -12,10 +12,11 @@ import warnings
...
@@ -12,10 +12,11 @@ import warnings
import
requests
import
requests
from
.graph_serialize
import
save_graphs
,
load_graphs
,
load_labels
from
.graph_serialize
import
save_graphs
,
load_graphs
,
load_labels
from
.tensor_serialize
import
save_tensors
,
load_tensors
__all__
=
[
'loadtxt'
,
'download'
,
'check_sha1'
,
'extract_archive'
,
__all__
=
[
'loadtxt'
,
'download'
,
'check_sha1'
,
'extract_archive'
,
'get_download_dir'
,
'Subset'
,
'split_dataset'
,
'get_download_dir'
,
'Subset'
,
'split_dataset'
,
'save_graphs'
,
"load_graphs"
,
"load_labels"
]
'save_graphs'
,
"load_graphs"
,
"load_labels"
,
"save_tensors"
,
"load_tensors"
]
def
loadtxt
(
path
,
delimiter
,
dtype
=
None
):
def
loadtxt
(
path
,
delimiter
,
dtype
=
None
):
try
:
try
:
...
...
src/graph/graph_serialize.cc
→
src/graph/
serialize/
graph_serialize.cc
View file @
1506560e
/*!
/*!
* Copyright (c) 2019 by Contributors
* Copyright (c) 2019 by Contributors
* \file graph/graph_serialize.cc
* \file graph/
serialize/
graph_serialize.cc
* \brief Graph serialization implementation
* \brief Graph serialization implementation
*
*
* The storage structure is
* The storage structure is
...
...
src/graph/graph_serialize.h
→
src/graph/
serialize/
graph_serialize.h
View file @
1506560e
/*!
/*!
* Copyright (c) 2019 by Contributors
* Copyright (c) 2019 by Contributors
* \file graph/graph_serialize.h
* \file graph/
serialize/
graph_serialize.h
* \brief Graph serialization header
* \brief Graph serialization header
*/
*/
#ifndef DGL_GRAPH_GRAPH_SERIALIZE_H_
#ifndef DGL_GRAPH_
SERIALIZE_
GRAPH_SERIALIZE_H_
#define DGL_GRAPH_GRAPH_SERIALIZE_H_
#define DGL_GRAPH_
SERIALIZE_
GRAPH_SERIALIZE_H_
#include <dgl/graph.h>
#include <dgl/graph.h>
#include <dgl/array.h>
#include <dgl/array.h>
...
@@ -20,7 +20,7 @@
...
@@ -20,7 +20,7 @@
#include <vector>
#include <vector>
#include <algorithm>
#include <algorithm>
#include <utility>
#include <utility>
#include "../c_api_common.h"
#include "../
../
c_api_common.h"
using
dgl
::
runtime
::
NDArray
;
using
dgl
::
runtime
::
NDArray
;
using
dgl
::
ImmutableGraph
;
using
dgl
::
ImmutableGraph
;
...
@@ -112,4 +112,4 @@ ImmutableGraphPtr ToImmutableGraph(GraphPtr g);
...
@@ -112,4 +112,4 @@ ImmutableGraphPtr ToImmutableGraph(GraphPtr g);
}
// namespace serialize
}
// namespace serialize
}
// namespace dgl
}
// namespace dgl
#endif // DGL_GRAPH_GRAPH_SERIALIZE_H_
#endif // DGL_GRAPH_
SERIALIZE_
GRAPH_SERIALIZE_H_
src/graph/serialize/tensor_serialize.cc
0 → 100644
View file @
1506560e
/*!
* Copyright (c) 2019 by Contributors
* \file graph/serialize/tensor_serialize.cc
* \brief Graph serialization implementation
*/
#include <dgl/packed_func_ext.h>
#include <dgl/runtime/container.h>
#include <dgl/runtime/ndarray.h>
#include <dgl/runtime/object.h>
#include <dmlc/io.h>
#include "../../c_api_common.h"
using
namespace
dgl
::
runtime
;
using
dmlc
::
SeekStream
;
namespace
dgl
{
namespace
serialize
{
typedef
std
::
pair
<
std
::
string
,
NDArray
>
NamedTensor
;
DGL_REGISTER_GLOBAL
(
"data.tensor_serialize._CAPI_SaveNDArrayDict"
)
.
set_body
([](
DGLArgs
args
,
DGLRetValue
*
rv
)
{
std
::
string
filename
=
args
[
0
];
Map
<
std
::
string
,
Value
>
nd_dict
=
args
[
1
];
std
::
vector
<
NamedTensor
>
namedTensors
;
for
(
auto
kv
:
nd_dict
)
{
NDArray
ndarray
=
static_cast
<
NDArray
>
(
kv
.
second
->
data
);
namedTensors
.
emplace_back
(
kv
.
first
,
ndarray
);
}
auto
*
fs
=
dynamic_cast
<
SeekStream
*>
(
SeekStream
::
Create
(
filename
.
c_str
(),
"w"
,
true
));
fs
->
Write
(
namedTensors
);
delete
fs
;
*
rv
=
true
;
});
DGL_REGISTER_GLOBAL
(
"data.tensor_serialize._CAPI_LoadNDArrayDict"
)
.
set_body
([](
DGLArgs
args
,
DGLRetValue
*
rv
)
{
std
::
string
filename
=
args
[
0
];
Map
<
std
::
string
,
Value
>
nd_dict
;
std
::
vector
<
NamedTensor
>
namedTensors
;
SeekStream
*
fs
=
SeekStream
::
CreateForRead
(
filename
.
c_str
(),
true
);
CHECK
(
fs
)
<<
"Filename is invalid or file doesn't exists"
;
fs
->
Read
(
&
namedTensors
);
for
(
auto
kv
:
namedTensors
)
{
Value
ndarray
=
Value
(
MakeValue
(
kv
.
second
));
nd_dict
.
Set
(
kv
.
first
,
ndarray
);
}
delete
fs
;
*
rv
=
nd_dict
;
});
}
// namespace serialize
}
// namespace dgl
tests/compute/test_
graph_
serialize.py
→
tests/compute/test_serialize.py
View file @
1506560e
...
@@ -7,7 +7,8 @@ import os
...
@@ -7,7 +7,8 @@ import os
from
dgl
import
DGLGraph
from
dgl
import
DGLGraph
import
dgl
import
dgl
from
dgl.data.utils
import
save_graphs
,
load_graphs
,
load_labels
import
dgl.ndarray
as
nd
from
dgl.data.utils
import
save_graphs
,
load_graphs
,
load_labels
,
save_tensors
,
load_tensors
np
.
random
.
seed
(
44
)
np
.
random
.
seed
(
44
)
...
@@ -133,7 +134,37 @@ def test_graph_serialize_with_labels():
...
@@ -133,7 +134,37 @@ def test_graph_serialize_with_labels():
os
.
unlink
(
path
)
os
.
unlink
(
path
)
def
test_serialize_tensors
():
# create a temporary file and immediately release it so DGL can open it.
f
=
tempfile
.
NamedTemporaryFile
(
delete
=
False
)
path
=
f
.
name
f
.
close
()
tensor_dict
=
{
"a"
:
F
.
tensor
(
[
1
,
3
,
-
1
,
0
],
dtype
=
F
.
int64
),
"1@1"
:
F
.
tensor
([
1.5
,
2
],
dtype
=
F
.
float32
)}
save_tensors
(
path
,
tensor_dict
)
load_tensor_dict
=
load_tensors
(
path
)
for
key
in
tensor_dict
:
assert
key
in
load_tensor_dict
assert
np
.
array_equal
(
F
.
asnumpy
(
load_tensor_dict
[
key
]),
F
.
asnumpy
(
tensor_dict
[
key
]))
load_nd_dict
=
load_tensors
(
path
,
return_dgl_ndarray
=
True
)
for
key
in
tensor_dict
:
assert
key
in
load_nd_dict
assert
isinstance
(
load_nd_dict
[
key
],
nd
.
NDArray
)
assert
np
.
array_equal
(
load_nd_dict
[
key
].
asnumpy
(),
F
.
asnumpy
(
tensor_dict
[
key
]))
os
.
unlink
(
path
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
test_graph_serialize_with_feature
()
test_graph_serialize_with_feature
()
test_graph_serialize_without_feature
()
test_graph_serialize_without_feature
()
test_graph_serialize_with_labels
()
test_graph_serialize_with_labels
()
test_serialize_tensors
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment