"vscode:/vscode.git/clone" did not exist on "6c7fad7ec8b2417c92326804e1751658874fd43b"
Unverified commit 33b9c383, authored by Minjie Wang and committed by GitHub

[Bugfix] fix bug in pickling (#456)

* fix bug in pickling

* fix lint
parent 6f603bbf
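At a high level, the change makes DGL's internal Scheme and Index objects picklable so that a whole graph survives a pickle round trip. A minimal usage sketch of what the fix enables, mirroring the tests added below (PyTorch backend assumed; this snippet is illustrative and not part of the diff):

    import io
    import pickle
    import networkx as nx
    import torch
    import dgl

    # build a small graph with node features, as the tests below do
    g = dgl.DGLGraph(nx.path_graph(5))
    g.ndata['x'] = torch.randn(5, 3)

    # round-trip through pickle and check the features survive
    buf = io.BytesIO()
    pickle.dump(g, buf)
    buf.seek(0)
    g2 = pickle.load(buf)
    assert g.number_of_nodes() == g2.number_of_nodes()
    assert torch.allclose(g.ndata['x'], g2.ndata['x'])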
@@ -4,7 +4,6 @@ from __future__ import absolute_import
from collections import namedtuple
from collections.abc import MutableMapping
import sys
import numpy as np
from . import backend as F
@@ -22,18 +21,15 @@ class Scheme(namedtuple('Scheme', ['shape', 'dtype'])):
    dtype : backend-specific type object
        The feature data type.
    """
    # FIXME:
    # Python 3.5.2 is unable to pickle torch dtypes; this is a workaround.
    # Pickling torch dtypes could be problematic; this is a workaround.
    # I also have to create data_type_dict and reverse_data_type_dict
    # attributes just for this bug.
    # I raised an issue in PyTorch bug tracker:
    # https://github.com/pytorch/pytorch/issues/14057
    if sys.version_info.major == 3 and sys.version_info.minor == 5:
        def __reduce__(self):
            state = (self.shape, F.reverse_data_type_dict[self.dtype])
            return self._reconstruct_scheme, state

        @classmethod
        def _reconstruct_scheme(cls, shape, dtype_str):
            dtype = F.data_type_dict[dtype_str]
......
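For context, the __reduce__ workaround above serializes the dtype as a backend-independent string and maps it back on unpickling. A minimal self-contained sketch of the same pattern, with plain dictionaries standing in for DGL's backend data_type_dict / reverse_data_type_dict (the table contents here are made up for illustration):

    import pickle
    from collections import namedtuple

    # stand-ins for the backend lookup tables referenced above
    DATA_TYPE_DICT = {'float32': float, 'int64': int}
    REVERSE_DATA_TYPE_DICT = {v: k for k, v in DATA_TYPE_DICT.items()}

    class Scheme(namedtuple('Scheme', ['shape', 'dtype'])):
        def __reduce__(self):
            # pickle the dtype as a string so the dtype object itself
            # never has to go through pickle
            state = (self.shape, REVERSE_DATA_TYPE_DICT[self.dtype])
            return self._reconstruct_scheme, state

        @classmethod
        def _reconstruct_scheme(cls, shape, dtype_str):
            # map the string back to a dtype object on unpickling
            return cls(shape, DATA_TYPE_DICT[dtype_str])

    s = Scheme((3, 4), float)
    assert pickle.loads(pickle.dumps(s)) == s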
@@ -128,6 +128,10 @@ class Index(object):
        return self._slice_data == slice(start, stop)

    def __getstate__(self):
        if self._slice_data is not None:
            # the index can be represented by a slice
            return self._slice_data
        else:
            return self.tousertensor()

    def __setstate__(self, state):
......
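The __setstate__ body is collapsed above, but the __getstate__ shown makes the protocol clear: pickle a compact slice when the index is a contiguous range, otherwise pickle the materialized tensor. A toy, self-contained illustration of that two-branch pattern (LazyRange and its attributes are invented for this sketch and are not DGL's Index API):

    import pickle

    class LazyRange:
        # toy stand-in: keep a slice when possible, else a materialized list
        def __init__(self, data):
            self._slice_data = data if isinstance(data, slice) else None
            self._list_data = None if isinstance(data, slice) else list(data)

        def tolist(self):
            if self._slice_data is not None:
                s = self._slice_data
                return list(range(s.start, s.stop))
            return self._list_data

        def __getstate__(self):
            # prefer the compact slice representation when it exists
            if self._slice_data is not None:
                return self._slice_data
            return self._list_data

        def __setstate__(self, state):
            # __init__ is not rerun on unpickling, so rebuild both fields here
            self._slice_data = state if isinstance(state, slice) else None
            self._list_data = None if isinstance(state, slice) else list(state)

    r2 = pickle.loads(pickle.dumps(LazyRange(slice(5, 10))))
    assert r2.tolist() == [5, 6, 7, 8, 9]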
import networkx as nx
import dgl
import dgl.contrib as contrib
from dgl.frame import Frame, FrameRef, Column
@@ -8,6 +9,56 @@ import dgl.function as fn
import pickle
import io
import torch
def _assert_is_identical(g, g2):
    assert g.is_multigraph == g2.is_multigraph
    assert g.is_readonly == g2.is_readonly
    assert g.number_of_nodes() == g2.number_of_nodes()
    src, dst = g.all_edges()
    src2, dst2 = g2.all_edges()
    assert F.array_equal(src, src2)
    assert F.array_equal(dst, dst2)
    assert len(g.ndata) == len(g2.ndata)
    assert len(g.edata) == len(g2.edata)
    for k in g.ndata:
        assert F.allclose(g.ndata[k], g2.ndata[k])
    for k in g.edata:
        assert F.allclose(g.edata[k], g2.edata[k])

def _assert_is_identical_nodeflow(nf1, nf2):
    assert nf1.is_multigraph == nf2.is_multigraph
    assert nf1.is_readonly == nf2.is_readonly
    assert nf1.number_of_nodes() == nf2.number_of_nodes()
    src, dst = nf1.all_edges()
    src2, dst2 = nf2.all_edges()
    assert F.array_equal(src, src2)
    assert F.array_equal(dst, dst2)
    assert nf1.num_layers == nf2.num_layers
    for i in range(nf1.num_layers):
        assert nf1.layer_size(i) == nf2.layer_size(i)
        assert nf1.layers[i].data.keys() == nf2.layers[i].data.keys()
        for k in nf1.layers[i].data:
            assert F.allclose(nf1.layers[i].data[k], nf2.layers[i].data[k])
    assert nf1.num_blocks == nf2.num_blocks
    for i in range(nf1.num_blocks):
        assert nf1.block_size(i) == nf2.block_size(i)
        assert nf1.blocks[i].data.keys() == nf2.blocks[i].data.keys()
        for k in nf1.blocks[i].data:
            assert F.allclose(nf1.blocks[i].data[k], nf2.blocks[i].data[k])

def _assert_is_identical_batchedgraph(bg1, bg2):
    _assert_is_identical(bg1, bg2)
    assert bg1.batch_size == bg2.batch_size
    assert bg1.batch_num_nodes == bg2.batch_num_nodes
    assert bg1.batch_num_edges == bg2.batch_num_edges

def _assert_is_identical_index(i1, i2):
    assert i1.slice_data() == i2.slice_data()
    assert F.array_equal(i1.tousertensor(), i2.tousertensor())

def _reconstruct_pickle(obj):
    f = io.BytesIO()
    pickle.dump(obj, f)
@@ -18,14 +69,17 @@ def _reconstruct_pickle(obj):
    return obj
def test_pickling_index():
    # normal index
    i = toindex([1, 2, 3])
    i.tousertensor()
    i.todgltensor()  # construct a dgl tensor which is unpicklable
    i2 = _reconstruct_pickle(i)
    _assert_is_identical_index(i, i2)
    assert F.array_equal(i2.tousertensor(), i.tousertensor())

    # slice index
    i = toindex(slice(5, 10))
    i2 = _reconstruct_pickle(i)
    _assert_is_identical_index(i, i2)

def test_pickling_graph_index():
    gi = create_graph_index()
@@ -60,44 +114,6 @@ def test_pickling_frame():
    fr = Frame()

def _assert_is_identical(g, g2):
    assert g.is_multigraph == g2.is_multigraph
    assert g.is_readonly == g2.is_readonly
    assert g.number_of_nodes() == g2.number_of_nodes()
    src, dst = g.all_edges()
    src2, dst2 = g2.all_edges()
    assert F.array_equal(src, src2)
    assert F.array_equal(dst, dst2)
    assert len(g.ndata) == len(g2.ndata)
    assert len(g.edata) == len(g2.edata)
    for k in g.ndata:
        assert F.allclose(g.ndata[k], g2.ndata[k])
    for k in g.edata:
        assert F.allclose(g.edata[k], g2.edata[k])

def _assert_is_identical_nodeflow(nf1, nf2):
    assert nf1.is_multigraph == nf2.is_multigraph
    assert nf1.is_readonly == nf2.is_readonly
    assert nf1.number_of_nodes() == nf2.number_of_nodes()
    src, dst = nf1.all_edges()
    src2, dst2 = nf2.all_edges()
    assert F.array_equal(src, src2)
    assert F.array_equal(dst, dst2)
    assert nf1.num_layers == nf2.num_layers
    for i in range(nf1.num_layers):
        assert nf1.layer_size(i) == nf2.layer_size(i)
        assert nf1.layers[i].data.keys() == nf2.layers[i].data.keys()
        for k in nf1.layers[i].data:
            assert F.allclose(nf1.layers[i].data[k], nf2.layers[i].data[k])
    assert nf1.num_blocks == nf2.num_blocks
    for i in range(nf1.num_blocks):
        assert nf1.block_size(i) == nf2.block_size(i)
        assert nf1.blocks[i].data.keys() == nf2.blocks[i].data.keys()
        for k in nf1.blocks[i].data:
            assert F.allclose(nf1.blocks[i].data[k], nf2.blocks[i].data[k])

def _global_message_func(nodes):
    return {'x': nodes.data['x']}
@@ -189,9 +205,19 @@ def test_pickling_nodeflow():
    new_nf = _reconstruct_pickle(nf)
    _assert_is_identical_nodeflow(nf, new_nf)

def test_pickling_batched_graph():
    glist = [nx.path_graph(i + 5) for i in range(5)]
    glist = [dgl.DGLGraph(g) for g in glist]
    bg = dgl.batch(glist)
    bg.ndata['x'] = F.randn((35, 5))
    bg.edata['y'] = F.randn((60, 3))
    new_bg = _reconstruct_pickle(bg)
    _assert_is_identical_batchedgraph(bg, new_bg)

if __name__ == '__main__':
    test_pickling_index()
    test_pickling_graph_index()
    test_pickling_frame()
    test_pickling_graph()
    test_pickling_nodeflow()
    test_pickling_batched_graph()
import networkx as nx
import dgl
import torch
import pickle
import io
def _reconstruct_pickle(obj):
    f = io.BytesIO()
    pickle.dump(obj, f)
    f.seek(0)
    obj = pickle.load(f)
    f.close()
    return obj

def test_pickling_batched_graph():
    # NOTE: this is a test for a weird bug mentioned in
    # https://github.com/dmlc/dgl/issues/438
    glist = [nx.path_graph(i + 5) for i in range(5)]
    glist = [dgl.DGLGraph(g) for g in glist]
    bg = dgl.batch(glist)
    bg.ndata['x'] = torch.randn((35, 5))
    bg.edata['y'] = torch.randn((60, 3))
    new_bg = _reconstruct_pickle(bg)

if __name__ == '__main__':
    test_pickling_batched_graph()