Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
c13903bf
Unverified
Commit
c13903bf
authored
Jul 13, 2020
by
Quan (Andy) Gan
Committed by
GitHub
Jul 13, 2020
Browse files
fix batched heterograph serializations (#1794)
parent
200340ab
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
43 additions
and
0 deletions
+43
-0
python/dgl/batched_heterograph.py
python/dgl/batched_heterograph.py
+8
-0
tests/compute/test_pickle.py
tests/compute/test_pickle.py
+35
-0
No files found.
python/dgl/batched_heterograph.py
View file @
c13903bf
...
@@ -302,6 +302,14 @@ class BatchedDGLHeteroGraph(DGLHeteroGraph):
...
@@ -302,6 +302,14 @@ class BatchedDGLHeteroGraph(DGLHeteroGraph):
batch_num_nodes
=
self
.
_batch_num_nodes
,
batch_num_nodes
=
self
.
_batch_num_nodes
,
batch_num_edges
=
self
.
_batch_num_edges
)
batch_num_edges
=
self
.
_batch_num_edges
)
def
__getstate__
(
self
):
state
=
super
().
__getstate__
()
return
state
,
self
.
_batch_size
,
self
.
_batch_num_nodes
,
self
.
_batch_num_edges
def
__setstate__
(
self
,
state
):
state
,
self
.
_batch_size
,
self
.
_batch_num_nodes
,
self
.
_batch_num_edges
=
state
super
().
__setstate__
(
state
)
def
unbatch_hetero
(
graph
):
def
unbatch_hetero
(
graph
):
"""Return the list of heterographs in this batch.
"""Return the list of heterographs in this batch.
...
...
tests/compute/test_pickle.py
View file @
c13903bf
...
@@ -78,6 +78,13 @@ def _assert_is_identical_batchedgraph(bg1, bg2):
...
@@ -78,6 +78,13 @@ def _assert_is_identical_batchedgraph(bg1, bg2):
assert
bg1
.
batch_num_nodes
==
bg2
.
batch_num_nodes
assert
bg1
.
batch_num_nodes
==
bg2
.
batch_num_nodes
assert
bg1
.
batch_num_edges
==
bg2
.
batch_num_edges
assert
bg1
.
batch_num_edges
==
bg2
.
batch_num_edges
def
_assert_is_identical_batchedhetero
(
bg1
,
bg2
):
_assert_is_identical_hetero
(
bg1
,
bg2
)
for
ntype
in
bg1
.
ntypes
:
assert
bg1
.
batch_num_nodes
(
ntype
)
==
bg2
.
batch_num_nodes
(
ntype
)
for
canonical_etype
in
bg1
.
canonical_etypes
:
assert
bg1
.
batch_num_edges
(
canonical_etype
)
==
bg2
.
batch_num_edges
(
canonical_etype
)
def
_assert_is_identical_index
(
i1
,
i2
):
def
_assert_is_identical_index
(
i1
,
i2
):
assert
i1
.
slice_data
()
==
i2
.
slice_data
()
assert
i1
.
slice_data
()
==
i2
.
slice_data
()
assert
F
.
array_equal
(
i1
.
tousertensor
(),
i2
.
tousertensor
())
assert
F
.
array_equal
(
i1
.
tousertensor
(),
i2
.
tousertensor
())
...
@@ -258,6 +265,33 @@ def test_pickling_heterograph():
...
@@ -258,6 +265,33 @@ def test_pickling_heterograph():
new_g
=
_reconstruct_pickle
(
g
)
new_g
=
_reconstruct_pickle
(
g
)
_assert_is_identical_hetero
(
g
,
new_g
)
_assert_is_identical_hetero
(
g
,
new_g
)
def
test_pickling_batched_heterograph
():
# copied from test_heterograph.create_test_heterograph()
plays_spmat
=
ssp
.
coo_matrix
(([
1
,
1
,
1
,
1
],
([
0
,
1
,
2
,
1
],
[
0
,
0
,
1
,
1
])))
wishes_nx
=
nx
.
DiGraph
()
wishes_nx
.
add_nodes_from
([
'u0'
,
'u1'
,
'u2'
],
bipartite
=
0
)
wishes_nx
.
add_nodes_from
([
'g0'
,
'g1'
],
bipartite
=
1
)
wishes_nx
.
add_edge
(
'u0'
,
'g1'
,
id
=
0
)
wishes_nx
.
add_edge
(
'u2'
,
'g0'
,
id
=
1
)
follows_g
=
dgl
.
graph
([(
0
,
1
),
(
1
,
2
)],
'user'
,
'follows'
)
plays_g
=
dgl
.
bipartite
(
plays_spmat
,
'user'
,
'plays'
,
'game'
)
wishes_g
=
dgl
.
bipartite
(
wishes_nx
,
'user'
,
'wishes'
,
'game'
)
develops_g
=
dgl
.
bipartite
([(
0
,
0
),
(
1
,
1
)],
'developer'
,
'develops'
,
'game'
)
g
=
dgl
.
hetero_from_relations
([
follows_g
,
plays_g
,
wishes_g
,
develops_g
])
g2
=
dgl
.
hetero_from_relations
([
follows_g
,
plays_g
,
wishes_g
,
develops_g
])
g
.
nodes
[
'user'
].
data
[
'u_h'
]
=
F
.
randn
((
3
,
4
))
g
.
nodes
[
'game'
].
data
[
'g_h'
]
=
F
.
randn
((
2
,
5
))
g
.
edges
[
'plays'
].
data
[
'p_h'
]
=
F
.
randn
((
4
,
6
))
g2
.
nodes
[
'user'
].
data
[
'u_h'
]
=
F
.
randn
((
3
,
4
))
g2
.
nodes
[
'game'
].
data
[
'g_h'
]
=
F
.
randn
((
2
,
5
))
g2
.
edges
[
'plays'
].
data
[
'p_h'
]
=
F
.
randn
((
4
,
6
))
bg
=
dgl
.
batch_hetero
([
g
,
g2
])
new_bg
=
_reconstruct_pickle
(
bg
)
_assert_is_identical_batchedhetero
(
bg
,
new_bg
)
@
unittest
.
skipIf
(
dgl
.
backend
.
backend_name
!=
"pytorch"
,
reason
=
"Only test for pytorch format file"
)
@
unittest
.
skipIf
(
dgl
.
backend
.
backend_name
!=
"pytorch"
,
reason
=
"Only test for pytorch format file"
)
def
test_pickling_heterograph_index_compatibility
():
def
test_pickling_heterograph_index_compatibility
():
plays_spmat
=
ssp
.
coo_matrix
(([
1
,
1
,
1
,
1
],
([
0
,
1
,
2
,
1
],
[
0
,
0
,
1
,
1
])))
plays_spmat
=
ssp
.
coo_matrix
(([
1
,
1
,
1
,
1
],
([
0
,
1
,
2
,
1
],
[
0
,
0
,
1
,
1
])))
...
@@ -287,3 +321,4 @@ if __name__ == '__main__':
...
@@ -287,3 +321,4 @@ if __name__ == '__main__':
test_pickling_nodeflow
()
test_pickling_nodeflow
()
test_pickling_batched_graph
()
test_pickling_batched_graph
()
test_pickling_heterograph
()
test_pickling_heterograph
()
test_pickling_batched_heterograph
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment