Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
33b9c383
Unverified
Commit
33b9c383
authored
Mar 18, 2019
by
Minjie Wang
Committed by
GitHub
Mar 18, 2019
Browse files
[Bugfix] fix bug in pickling (#456)
* fix bug in pickling * fix lint
parent
6f603bbf
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
107 additions
and
55 deletions
+107
-55
python/dgl/frame.py
python/dgl/frame.py
+9
-13
python/dgl/utils.py
python/dgl/utils.py
+5
-1
tests/compute/test_pickle.py
tests/compute/test_pickle.py
+67
-41
tests/pytorch/test_pickle.py
tests/pytorch/test_pickle.py
+26
-0
No files found.
python/dgl/frame.py
View file @
33b9c383
...
...
@@ -4,7 +4,6 @@ from __future__ import absolute_import
from
collections
import
namedtuple
from
collections.abc
import
MutableMapping
import
sys
import
numpy
as
np
from
.
import
backend
as
F
...
...
@@ -22,22 +21,19 @@ class Scheme(namedtuple('Scheme', ['shape', 'dtype'])):
dtype : backend-specific type object
The feature data type.
"""
# FIXME:
# Python 3.5.2 is unable to pickle torch dtypes; this is a workaround.
# Pickling torch dtypes could be problemetic; this is a workaround.
# I also have to create data_type_dict and reverse_data_type_dict
# attribute just for this bug.
# I raised an issue in PyTorch bug tracker:
# https://github.com/pytorch/pytorch/issues/14057
if
sys
.
version_info
.
major
==
3
and
sys
.
version_info
.
minor
==
5
:
def
__reduce__
(
self
):
state
=
(
self
.
shape
,
F
.
reverse_data_type_dict
[
self
.
dtype
])
return
self
.
_reconstruct_scheme
,
state
@
classmethod
def
_reconstruct_scheme
(
cls
,
shape
,
dtype_str
):
dtype
=
F
.
data_type_dict
[
dtype_str
]
return
cls
(
shape
,
dtype
)
def
__reduce__
(
self
):
state
=
(
self
.
shape
,
F
.
reverse_data_type_dict
[
self
.
dtype
])
return
self
.
_reconstruct_scheme
,
state
@
classmethod
def
_reconstruct_scheme
(
cls
,
shape
,
dtype_str
):
dtype
=
F
.
data_type_dict
[
dtype_str
]
return
cls
(
shape
,
dtype
)
def
infer_scheme
(
tensor
):
"""Infer column scheme from the given tensor data.
...
...
python/dgl/utils.py
View file @
33b9c383
...
...
@@ -128,7 +128,11 @@ class Index(object):
return
self
.
_slice_data
==
slice
(
start
,
stop
)
def
__getstate__
(
self
):
return
self
.
tousertensor
()
if
self
.
_slice_data
is
not
None
:
# the index can be represented by a slice
return
self
.
_slice_data
else
:
return
self
.
tousertensor
()
def
__setstate__
(
self
,
state
):
self
.
_initialize_data
(
state
)
...
...
tests/compute/test_pickle.py
View file @
33b9c383
import
networkx
as
nx
import
dgl
import
dgl.contrib
as
contrib
from
dgl.frame
import
Frame
,
FrameRef
,
Column
...
...
@@ -8,6 +9,56 @@ import dgl.function as fn
import
pickle
import
io
import
torch
def
_assert_is_identical
(
g
,
g2
):
assert
g
.
is_multigraph
==
g2
.
is_multigraph
assert
g
.
is_readonly
==
g2
.
is_readonly
assert
g
.
number_of_nodes
()
==
g2
.
number_of_nodes
()
src
,
dst
=
g
.
all_edges
()
src2
,
dst2
=
g2
.
all_edges
()
assert
F
.
array_equal
(
src
,
src2
)
assert
F
.
array_equal
(
dst
,
dst2
)
assert
len
(
g
.
ndata
)
==
len
(
g2
.
ndata
)
assert
len
(
g
.
edata
)
==
len
(
g2
.
edata
)
for
k
in
g
.
ndata
:
assert
F
.
allclose
(
g
.
ndata
[
k
],
g2
.
ndata
[
k
])
for
k
in
g
.
edata
:
assert
F
.
allclose
(
g
.
edata
[
k
],
g2
.
edata
[
k
])
def
_assert_is_identical_nodeflow
(
nf1
,
nf2
):
assert
nf1
.
is_multigraph
==
nf2
.
is_multigraph
assert
nf1
.
is_readonly
==
nf2
.
is_readonly
assert
nf1
.
number_of_nodes
()
==
nf2
.
number_of_nodes
()
src
,
dst
=
nf1
.
all_edges
()
src2
,
dst2
=
nf2
.
all_edges
()
assert
F
.
array_equal
(
src
,
src2
)
assert
F
.
array_equal
(
dst
,
dst2
)
assert
nf1
.
num_layers
==
nf2
.
num_layers
for
i
in
range
(
nf1
.
num_layers
):
assert
nf1
.
layer_size
(
i
)
==
nf2
.
layer_size
(
i
)
assert
nf1
.
layers
[
i
].
data
.
keys
()
==
nf2
.
layers
[
i
].
data
.
keys
()
for
k
in
nf1
.
layers
[
i
].
data
:
assert
F
.
allclose
(
nf1
.
layers
[
i
].
data
[
k
],
nf2
.
layers
[
i
].
data
[
k
])
assert
nf1
.
num_blocks
==
nf2
.
num_blocks
for
i
in
range
(
nf1
.
num_blocks
):
assert
nf1
.
block_size
(
i
)
==
nf2
.
block_size
(
i
)
assert
nf1
.
blocks
[
i
].
data
.
keys
()
==
nf2
.
blocks
[
i
].
data
.
keys
()
for
k
in
nf1
.
blocks
[
i
].
data
:
assert
F
.
allclose
(
nf1
.
blocks
[
i
].
data
[
k
],
nf2
.
blocks
[
i
].
data
[
k
])
def
_assert_is_identical_batchedgraph
(
bg1
,
bg2
):
_assert_is_identical
(
bg1
,
bg2
)
assert
bg1
.
batch_size
==
bg2
.
batch_size
assert
bg1
.
batch_num_nodes
==
bg2
.
batch_num_nodes
assert
bg1
.
batch_num_edges
==
bg2
.
batch_num_edges
def
_assert_is_identical_index
(
i1
,
i2
):
assert
i1
.
slice_data
()
==
i2
.
slice_data
()
assert
F
.
array_equal
(
i1
.
tousertensor
(),
i2
.
tousertensor
())
def
_reconstruct_pickle
(
obj
):
f
=
io
.
BytesIO
()
pickle
.
dump
(
obj
,
f
)
...
...
@@ -18,14 +69,17 @@ def _reconstruct_pickle(obj):
return
obj
def
test_pickling_index
():
# normal index
i
=
toindex
([
1
,
2
,
3
])
i
.
tousertensor
()
i
.
todgltensor
()
# construct a dgl tensor which is unpicklable
i2
=
_reconstruct_pickle
(
i
)
_assert_is_identical_index
(
i
,
i2
)
assert
F
.
array_equal
(
i2
.
tousertensor
(),
i
.
tousertensor
())
# slice index
i
=
toindex
(
slice
(
5
,
10
))
i2
=
_reconstruct_pickle
(
i
)
_assert_is_identical_index
(
i
,
i2
)
def
test_pickling_graph_index
():
gi
=
create_graph_index
()
...
...
@@ -60,44 +114,6 @@ def test_pickling_frame():
fr
=
Frame
()
def
_assert_is_identical
(
g
,
g2
):
assert
g
.
is_multigraph
==
g2
.
is_multigraph
assert
g
.
is_readonly
==
g2
.
is_readonly
assert
g
.
number_of_nodes
()
==
g2
.
number_of_nodes
()
src
,
dst
=
g
.
all_edges
()
src2
,
dst2
=
g2
.
all_edges
()
assert
F
.
array_equal
(
src
,
src2
)
assert
F
.
array_equal
(
dst
,
dst2
)
assert
len
(
g
.
ndata
)
==
len
(
g2
.
ndata
)
assert
len
(
g
.
edata
)
==
len
(
g2
.
edata
)
for
k
in
g
.
ndata
:
assert
F
.
allclose
(
g
.
ndata
[
k
],
g2
.
ndata
[
k
])
for
k
in
g
.
edata
:
assert
F
.
allclose
(
g
.
edata
[
k
],
g2
.
edata
[
k
])
def
_assert_is_identical_nodeflow
(
nf1
,
nf2
):
assert
nf1
.
is_multigraph
==
nf2
.
is_multigraph
assert
nf1
.
is_readonly
==
nf2
.
is_readonly
assert
nf1
.
number_of_nodes
()
==
nf2
.
number_of_nodes
()
src
,
dst
=
nf1
.
all_edges
()
src2
,
dst2
=
nf2
.
all_edges
()
assert
F
.
array_equal
(
src
,
src2
)
assert
F
.
array_equal
(
dst
,
dst2
)
assert
nf1
.
num_layers
==
nf2
.
num_layers
for
i
in
range
(
nf1
.
num_layers
):
assert
nf1
.
layer_size
(
i
)
==
nf2
.
layer_size
(
i
)
assert
nf1
.
layers
[
i
].
data
.
keys
()
==
nf2
.
layers
[
i
].
data
.
keys
()
for
k
in
nf1
.
layers
[
i
].
data
:
assert
F
.
allclose
(
nf1
.
layers
[
i
].
data
[
k
],
nf2
.
layers
[
i
].
data
[
k
])
assert
nf1
.
num_blocks
==
nf2
.
num_blocks
for
i
in
range
(
nf1
.
num_blocks
):
assert
nf1
.
block_size
(
i
)
==
nf2
.
block_size
(
i
)
assert
nf1
.
blocks
[
i
].
data
.
keys
()
==
nf2
.
blocks
[
i
].
data
.
keys
()
for
k
in
nf1
.
blocks
[
i
].
data
:
assert
F
.
allclose
(
nf1
.
blocks
[
i
].
data
[
k
],
nf2
.
blocks
[
i
].
data
[
k
])
def
_global_message_func
(
nodes
):
return
{
'x'
:
nodes
.
data
[
'x'
]}
...
...
@@ -189,9 +205,19 @@ def test_pickling_nodeflow():
new_nf
=
_reconstruct_pickle
(
nf
)
_assert_is_identical_nodeflow
(
nf
,
new_nf
)
def
test_pickling_batched_graph
():
glist
=
[
nx
.
path_graph
(
i
+
5
)
for
i
in
range
(
5
)]
glist
=
[
dgl
.
DGLGraph
(
g
)
for
g
in
glist
]
bg
=
dgl
.
batch
(
glist
)
bg
.
ndata
[
'x'
]
=
F
.
randn
((
35
,
5
))
bg
.
edata
[
'y'
]
=
F
.
randn
((
60
,
3
))
new_bg
=
_reconstruct_pickle
(
bg
)
_assert_is_identical_batchedgraph
(
bg
,
new_bg
)
if
__name__
==
'__main__'
:
test_pickling_index
()
test_pickling_graph_index
()
test_pickling_frame
()
test_pickling_graph
()
test_pickling_nodeflow
()
test_pickling_batched_graph
()
tests/pytorch/test_pickle.py
0 → 100644
View file @
33b9c383
import
networkx
as
nx
import
dgl
import
torch
import
pickle
import
io
def
_reconstruct_pickle
(
obj
):
f
=
io
.
BytesIO
()
pickle
.
dump
(
obj
,
f
)
f
.
seek
(
0
)
obj
=
pickle
.
load
(
f
)
f
.
close
()
return
obj
def
test_pickling_batched_graph
():
# NOTE: this is a test for a wierd bug mentioned in
# https://github.com/dmlc/dgl/issues/438
glist
=
[
nx
.
path_graph
(
i
+
5
)
for
i
in
range
(
5
)]
glist
=
[
dgl
.
DGLGraph
(
g
)
for
g
in
glist
]
bg
=
dgl
.
batch
(
glist
)
bg
.
ndata
[
'x'
]
=
torch
.
randn
((
35
,
5
))
bg
.
edata
[
'y'
]
=
torch
.
randn
((
60
,
3
))
new_bg
=
_reconstruct_pickle
(
bg
)
if
__name__
==
'__main__'
:
test_pickling_batched_graph
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment