Unverified Commit 79a51025 authored by Gan Quan's avatar Gan Quan Committed by GitHub
Browse files

Pickling support (#155)

* pickling support

* resorting to suggested way of pickling

* custom attribute pickling check

* working around a weird pytorch pickling bug

* including partial frame case

* pickling everything now

* fix as requested
parent 05f464f8
...@@ -27,6 +27,12 @@ def _load_backend(): ...@@ -27,6 +27,12 @@ def _load_backend():
data_type_dict = mod.__dict__[api]() data_type_dict = mod.__dict__[api]()
for name, dtype in data_type_dict.items(): for name, dtype in data_type_dict.items():
setattr(thismod, name, dtype) setattr(thismod, name, dtype)
# override data type dict function
setattr(thismod, 'data_type_dict', data_type_dict)
setattr(thismod,
'reverse_data_type_dict',
{v: k for k, v in data_type_dict.items()})
else: else:
# load functions # load functions
if api in mod.__dict__: if api in mod.__dict__:
......
...@@ -8,6 +8,8 @@ from . import backend as F ...@@ -8,6 +8,8 @@ from . import backend as F
from .base import DGLError, dgl_warning from .base import DGLError, dgl_warning
from . import utils from . import utils
import sys
class Scheme(namedtuple('Scheme', ['shape', 'dtype'])): class Scheme(namedtuple('Scheme', ['shape', 'dtype'])):
"""The column scheme. """The column scheme.
...@@ -19,7 +21,21 @@ class Scheme(namedtuple('Scheme', ['shape', 'dtype'])): ...@@ -19,7 +21,21 @@ class Scheme(namedtuple('Scheme', ['shape', 'dtype'])):
dtype : TVMType dtype : TVMType
The feature data type. The feature data type.
""" """
pass # FIXME:
# Python 3.5.2 is unable to pickle torch dtypes; this is a workaround.
# I also have to create data_type_dict and reverse_data_type_dict
# attribute just for this bug.
# I raised an issue in PyTorch bug tracker:
# https://github.com/pytorch/pytorch/issues/14057
# Workaround for a Python 3.5-only bug: torch dtype objects cannot be
# pickled directly (see pytorch/pytorch#14057, referenced in the FIXME
# above), so on 3.5 a Scheme pickles its dtype by *name* using the
# backend's data_type_dict / reverse_data_type_dict lookup tables.
if sys.version_info.major == 3 and sys.version_info.minor == 5:
    def __reduce__(self):
        # Pickle as (callable, args): unpickling invokes
        # _reconstruct_scheme(shape, dtype_name) instead of pickling
        # the dtype object itself.
        return self._reconstruct_scheme, \
            (self.shape, F.reverse_data_type_dict[self.dtype])

    @classmethod
    def _reconstruct_scheme(cls, shape, dtype_str):
        """Rebuild a Scheme from a shape and a dtype *name* string."""
        # Map the string name back to the backend dtype object.
        dtype = F.data_type_dict[dtype_str]
        return cls(shape, dtype)
def infer_scheme(tensor): def infer_scheme(tensor):
return Scheme(tuple(F.shape(tensor)[1:]), F.dtype(tensor)) return Scheme(tuple(F.shape(tensor)[1:]), F.dtype(tensor))
...@@ -143,6 +159,11 @@ class Column(object): ...@@ -143,6 +159,11 @@ class Column(object):
else: else:
return Column(data) return Column(data)
def _default_zero_initializer(shape, dtype, ctx):
    """Return a zero-filled tensor of the given shape and dtype on *ctx*.

    NOTE(review): defined at module level (replacing an inline lambda in
    Frame) — presumably so that Frame objects holding the initializer stay
    picklable, since lambdas cannot be pickled by reference; confirm against
    the pickling commit this belongs to.
    """
    return F.zeros(shape, dtype, ctx)
class Frame(MutableMapping): class Frame(MutableMapping):
"""The columnar storage for node/edge features. """The columnar storage for node/edge features.
...@@ -183,7 +204,7 @@ class Frame(MutableMapping): ...@@ -183,7 +204,7 @@ class Frame(MutableMapping):
dgl_warning('Initializer is not set. Use zero initializer instead.' dgl_warning('Initializer is not set. Use zero initializer instead.'
' To suppress this warning, use `set_initializer` to' ' To suppress this warning, use `set_initializer` to'
' explicitly specify which initializer to use.') ' explicitly specify which initializer to use.')
self._initializer = lambda shape, dtype, ctx: F.zeros(shape, dtype, ctx) self._initializer = _default_zero_initializer
def set_initializer(self, initializer): def set_initializer(self, initializer):
"""Set the initializer for empty values. """Set the initializer for empty values.
......
...@@ -641,6 +641,26 @@ class GraphIndex(object): ...@@ -641,6 +641,26 @@ class GraphIndex(object):
handle = _CAPI_DGLGraphLineGraph(self._handle, backtracking) handle = _CAPI_DGLGraphLineGraph(self._handle, backtracking)
return GraphIndex(handle) return GraphIndex(handle)
def __getstate__(self):
    """Return the pickle state: a 4-tuple
    (number_of_nodes, multigraph, src_nodes, dst_nodes).

    The underlying C graph handle is not picklable, so the graph is
    reduced to its node count, multigraph flag, and edge endpoint lists.
    """
    src, dst, _ = self.edges()
    n_nodes = self.number_of_nodes()
    multigraph = self.is_multigraph()
    return n_nodes, multigraph, src, dst

def __setstate__(self, state):
    """Rebuild the graph from the pickle state, a 4-tuple
    (number_of_nodes, multigraph, src_nodes, dst_nodes).
    """
    n_nodes, multigraph, src, dst = state
    # Recreate the C-side handle and drop any cached derived data
    # before repopulating nodes and edges.
    self._handle = _CAPI_DGLGraphCreate(multigraph)
    self._cache = {}
    self.clear()
    self.add_nodes(n_nodes)
    self.add_edges(src, dst)
class SubgraphIndex(GraphIndex): class SubgraphIndex(GraphIndex):
"""Graph index for subgraph. """Graph index for subgraph.
...@@ -695,6 +715,14 @@ class SubgraphIndex(GraphIndex): ...@@ -695,6 +715,14 @@ class SubgraphIndex(GraphIndex):
""" """
return self._induced_edges return self._induced_edges
def __getstate__(self):
    # Explicitly reject pickling rather than inherit GraphIndex's
    # implementation, which would silently drop the induced node/edge
    # mappings that make this a subgraph view.
    raise NotImplementedError(
        "SubgraphIndex pickling is not supported yet.")

def __setstate__(self, state):
    # Kept in sync with __getstate__: unpickling is likewise rejected.
    raise NotImplementedError(
        "SubgraphIndex unpickling is not supported yet.")
def map_to_subgraph_nid(subgraph, parent_nids): def map_to_subgraph_nid(subgraph, parent_nids):
"""Map parent node Ids to the subgraph node Ids. """Map parent node Ids to the subgraph node Ids.
......
...@@ -11,6 +11,9 @@ from . import ndarray as nd ...@@ -11,6 +11,9 @@ from . import ndarray as nd
class Index(object): class Index(object):
"""Index class that can be easily converted to list/tensor.""" """Index class that can be easily converted to list/tensor."""
def __init__(self, data): def __init__(self, data):
self._initialize_data(data)
def _initialize_data(self, data):
self._list_data = None # a numpy type data self._list_data = None # a numpy type data
self._user_tensor_data = dict() # dictionary of user tensors self._user_tensor_data = dict() # dictionary of user tensors
self._dgl_tensor_data = None # a dgl ndarray self._dgl_tensor_data = None # a dgl ndarray
...@@ -93,6 +96,12 @@ class Index(object): ...@@ -93,6 +96,12 @@ class Index(object):
def __getitem__(self, i): def __getitem__(self, i):
return self.tolist()[i] return self.tolist()[i]
def __getstate__(self):
    # Pickle only the user-tensor form; the dgl ndarray representation
    # (self._dgl_tensor_data) is not picklable and is rebuilt lazily.
    return self.tousertensor()

def __setstate__(self, state):
    # Reinitialize from the pickled user tensor; other representations
    # are reconstructed on demand.
    self._initialize_data(state)
def toindex(x): def toindex(x):
return x if isinstance(x, Index) else Index(x) return x if isinstance(x, Index) else Index(x)
......
import dgl
from dgl.frame import Frame, FrameRef, Column
from dgl.graph_index import create_graph_index
from dgl.utils import toindex
import dgl.backend as backend
import dgl.function as F
import utils as U
import torch
import pickle
import io
def _reconstruct_pickle(obj):
f = io.BytesIO()
pickle.dump(obj, f)
f.seek(0)
obj = pickle.load(f)
f.close()
return obj
def test_pickling_index():
    """Round-trip an Index through pickle and compare tensor contents."""
    idx = toindex([1, 2, 3])
    # Materialize both backing representations; the dgl tensor form is
    # unpicklable, so this exercises the __getstate__ workaround.
    idx.tousertensor()
    idx.todgltensor()
    restored = _reconstruct_pickle(idx)
    assert torch.equal(restored.tousertensor(), idx.tousertensor())
def test_pickling_graph_index():
    """Pickle a GraphIndex and check nodes and edges survive the trip."""
    graph_idx = create_graph_index()
    graph_idx.add_nodes(3)
    src = toindex([0, 0])
    dst = toindex([1, 2])
    graph_idx.add_edges(src, dst)

    restored = _reconstruct_pickle(graph_idx)
    assert restored.number_of_nodes() == graph_idx.number_of_nodes()
    restored_src, restored_dst, _ = restored.edges()
    assert torch.equal(restored_src.tousertensor(), src.tousertensor())
    assert torch.equal(restored_dst.tousertensor(), dst.tousertensor())
def test_pickling_frame():
    """Pickle round-trips for Column and Frame, including an empty Frame."""
    x = torch.randn(3, 7)
    y = torch.randn(3, 5)

    c = Column(x)
    c2 = _reconstruct_pickle(c)
    assert U.allclose(c.data, c2.data)

    fr = Frame({'x': x, 'y': y})
    fr2 = _reconstruct_pickle(fr)
    assert U.allclose(fr2['x'].data, x)
    assert U.allclose(fr2['y'].data, y)

    # The original version constructed this empty Frame but never used it
    # (dead assignment); actually round-trip it so the empty-frame
    # pickling path is exercised too.
    fr = Frame()
    fr2 = _reconstruct_pickle(fr)
    assert len(fr2) == 0
def _assert_is_identical(g, g2):
    """Assert two graphs agree on structure and on every node/edge feature."""
    assert g.number_of_nodes() == g2.number_of_nodes()
    src_a, dst_a = g.all_edges()
    src_b, dst_b = g2.all_edges()
    assert torch.equal(src_a, src_b)
    assert torch.equal(dst_a, dst_b)

    # Same feature fields on both sides, with matching values.
    assert len(g.ndata) == len(g2.ndata)
    assert len(g.edata) == len(g2.edata)
    for key in g.ndata:
        assert U.allclose(g.ndata[key], g2.ndata[key])
    for key in g.edata:
        assert U.allclose(g.edata[key], g2.edata[key])
def _global_message_func(nodes):
return {'x': nodes.data['x']}
def test_pickling_graph():
    """End-to-end pickling: structure, features, registered functions,
    custom attributes, and a batched graph with partially-set features."""
    g = dgl.DGLGraph()
    g.add_nodes(3)
    g.add_edges(torch.LongTensor([0, 0]), torch.LongTensor([1, 2]))

    node_x = torch.randn(3, 7)
    node_y = torch.randn(3, 5)
    edge_a = torch.randn(2, 6)
    edge_b = torch.randn(2, 4)
    g.ndata['x'] = node_x
    g.ndata['y'] = node_y
    g.edata['a'] = edge_a
    g.edata['b'] = edge_b

    # Registered functions must survive pickling.
    g.register_message_func(_global_message_func)
    reduce_func = F.sum('x', 'x')
    g.register_reduce_func(reduce_func)
    # So must custom instance attributes.
    g.foo = 2

    restored = _reconstruct_pickle(g)
    _assert_is_identical(g, restored)
    assert restored.foo == 2
    assert restored._message_func == _global_message_func
    assert isinstance(restored._reduce_func, type(reduce_func))
    assert restored._reduce_func._name == 'sum'
    assert restored._reduce_func.op == backend.sum
    assert restored._reduce_func.msg_field == 'x'
    assert restored._reduce_func.out_field == 'x'

    # Batched graph where one component has a partially-set feature.
    g2 = dgl.DGLGraph()
    g2.add_nodes(4)
    g2.add_edges(torch.LongTensor([0, 1]), torch.LongTensor([2, 3]))
    g2.ndata['x'] = torch.randn(4, 7)
    # 'y' is set on only three of the four nodes (partial frame case).
    g2.nodes[[0, 1, 3]].data['y'] = torch.randn(3, 5)
    g2.edata['a'] = torch.randn(2, 6)
    g2.edata['b'] = torch.randn(2, 4)

    batched = dgl.batch([g, g2])
    batched2 = _reconstruct_pickle(batched)
    _assert_is_identical(batched, batched2)
    part1, part2 = dgl.unbatch(batched2)
    _assert_is_identical(g, part1)
    _assert_is_identical(g2, part2)
if __name__ == '__main__':
    # Run every pickling test case in order when executed as a script.
    for case in (test_pickling_index,
                 test_pickling_graph_index,
                 test_pickling_frame,
                 test_pickling_graph):
        case()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment