Unverified Commit 79a51025 authored by Gan Quan, committed by GitHub

Pickling support (#155)

* pickling support

* resorting to suggested way of pickling

* custom attribute pickling check

* working around a weird pytorch pickling bug

* including partial frame case

* pickling everything now

* fix as requested
parent 05f464f8
@@ -27,6 +27,12 @@ def _load_backend():
data_type_dict = mod.__dict__[api]()
for name, dtype in data_type_dict.items():
setattr(thismod, name, dtype)
# override data type dict function
setattr(thismod, 'data_type_dict', data_type_dict)
setattr(thismod,
'reverse_data_type_dict',
{v: k for k, v in data_type_dict.items()})
else:
# load functions
if api in mod.__dict__:
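A minimal sketch of what this override provides (assuming the PyTorch backend, where the name 'float32' maps to torch.float32): after _load_backend runs, the backend module exposes the name-to-dtype mapping as a plain dict plus its reverse, which the Scheme pickling workaround below relies on.

import dgl.backend as F

# data_type_dict is now the resolved dict rather than a function,
# and reverse_data_type_dict maps each backend dtype back to its string name.
dtype = F.data_type_dict['float32']            # e.g. torch.float32 under the PyTorch backend
assert F.reverse_data_type_dict[dtype] == 'float32'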
@@ -8,6 +8,8 @@ from . import backend as F
from .base import DGLError, dgl_warning
from . import utils
import sys
class Scheme(namedtuple('Scheme', ['shape', 'dtype'])):
"""The column scheme.
@@ -19,7 +21,21 @@ class Scheme(namedtuple('Scheme', ['shape', 'dtype'])):
dtype : TVMType
The feature data type.
"""
pass
# FIXME:
# Python 3.5.2 is unable to pickle torch dtypes; this is a workaround.
# I also have to create the data_type_dict and reverse_data_type_dict
# attributes just for this bug.
# I raised an issue in PyTorch bug tracker:
# https://github.com/pytorch/pytorch/issues/14057
if sys.version_info.major == 3 and sys.version_info.minor == 5:
def __reduce__(self):
return self._reconstruct_scheme, \
(self.shape, F.reverse_data_type_dict[self.dtype])
@classmethod
def _reconstruct_scheme(cls, shape, dtype_str):
dtype = F.data_type_dict[dtype_str]
return cls(shape, dtype)
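For reference, a minimal sketch of the mechanics on Python 3.5 (PyTorch backend assumed; on newer Pythons the default namedtuple pickling is used): pickle calls __reduce__, stores the reconstructor together with the shape and the dtype's string name, and replays that call on load.

import pickle
import torch
from dgl.frame import Scheme

s = Scheme(shape=(7,), dtype=torch.float32)
# Under Python 3.5 pickle would record
#   (Scheme._reconstruct_scheme, ((7,), 'float32'))
# so only the dtype's string name, not the torch dtype object, is serialized.
s2 = pickle.loads(pickle.dumps(s))
assert s2 == s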
def infer_scheme(tensor):
return Scheme(tuple(F.shape(tensor)[1:]), F.dtype(tensor))
@@ -143,6 +159,11 @@ class Column(object):
else:
return Column(data)
def _default_zero_initializer(shape, dtype, ctx):
return F.zeros(shape, dtype, ctx)
class Frame(MutableMapping):
"""The columnar storage for node/edge features.
@@ -183,7 +204,7 @@ class Frame(MutableMapping):
dgl_warning('Initializer is not set. Use zero initializer instead.'
' To suppress this warning, use `set_initializer` to'
' explicitly specify which initializer to use.')
self._initializer = lambda shape, dtype, ctx: F.zeros(shape, dtype, ctx)
self._initializer = _default_zero_initializer
def set_initializer(self, initializer):
"""Set the initializer for empty values.
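The switch from the lambda to _default_zero_initializer matters for pickling: the initializer is stored on the Frame instance, and module-level functions pickle by reference while lambdas do not. A minimal stand-alone sketch of the difference (not DGL-specific):

import pickle

def module_level_initializer(shape, dtype, ctx):   # picklable by qualified name
    return None

lambda_initializer = lambda shape, dtype, ctx: None

pickle.dumps(module_level_initializer)             # works
try:
    pickle.dumps(lambda_initializer)               # fails: a lambda has no importable name
except pickle.PicklingError:
    pass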
@@ -641,6 +641,26 @@ class GraphIndex(object):
handle = _CAPI_DGLGraphLineGraph(self._handle, backtracking)
return GraphIndex(handle)
def __getstate__(self):
src, dst, _ = self.edges()
n_nodes = self.number_of_nodes()
multigraph = self.is_multigraph()
return n_nodes, multigraph, src, dst
def __setstate__(self, state):
"""The pickle state of GraphIndex is defined as a triplet
(number_of_nodes, multigraph, src_nodes, dst_nodes)
"""
n_nodes, multigraph, src, dst = state
self._handle = _CAPI_DGLGraphCreate(multigraph)
self._cache = {}
self.clear()
self.add_nodes(n_nodes)
self.add_edges(src, dst)
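A minimal sketch of the round trip these two methods define, mirroring what pickle does internally (create_graph_index and toindex as used in the tests below); in practice one simply calls pickle.dumps/loads on the index as in test_pickling_graph_index.

from dgl.graph_index import GraphIndex, create_graph_index
from dgl.utils import toindex

gi = create_graph_index()
gi.add_nodes(3)
gi.add_edges(toindex([0, 0]), toindex([1, 2]))

state = gi.__getstate__()                  # (n_nodes, multigraph, src, dst)
restored = GraphIndex.__new__(GraphIndex)  # pickle creates the instance without __init__
restored.__setstate__(state)               # re-creates the C handle and replays the edges
assert restored.number_of_nodes() == gi.number_of_nodes()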
class SubgraphIndex(GraphIndex):
"""Graph index for subgraph.
@@ -695,6 +715,14 @@ class SubgraphIndex(GraphIndex):
"""
return self._induced_edges
def __getstate__(self):
raise NotImplementedError(
"SubgraphIndex pickling is not supported yet.")
def __setstate__(self, state):
raise NotImplementedError(
"SubgraphIndex unpickling is not supported yet.")
def map_to_subgraph_nid(subgraph, parent_nids):
"""Map parent node Ids to the subgraph node Ids.
@@ -11,6 +11,9 @@ from . import ndarray as nd
class Index(object):
"""Index class that can be easily converted to list/tensor."""
def __init__(self, data):
self._initialize_data(data)
def _initialize_data(self, data):
self._list_data = None # a numpy type data
self._user_tensor_data = dict() # dictionary of user tensors
self._dgl_tensor_data = None # a dgl ndarray
@@ -93,6 +96,12 @@ class Index(object):
def __getitem__(self, i):
return self.tolist()[i]
def __getstate__(self):
return self.tousertensor()
def __setstate__(self, state):
self._initialize_data(state)
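Here the pickled state is just the user tensor: __getstate__ drops the cached DGL ndarray (which cannot be pickled) and __setstate__ rebuilds the internal caches from that tensor on load. A minimal sketch, assuming the PyTorch backend:

import torch
from dgl.utils import toindex

idx = toindex([1, 2, 3])
idx.todgltensor()                # populate the dgl ndarray cache (not picklable)
state = idx.__getstate__()       # only the user tensor survives
assert torch.is_tensor(state)
assert torch.equal(state, idx.tousertensor())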
def toindex(x):
return x if isinstance(x, Index) else Index(x)
import dgl
from dgl.frame import Frame, FrameRef, Column
from dgl.graph_index import create_graph_index
from dgl.utils import toindex
import dgl.backend as backend
import dgl.function as F
import utils as U
import torch
import pickle
import io
def _reconstruct_pickle(obj):
f = io.BytesIO()
pickle.dump(obj, f)
f.seek(0)
obj = pickle.load(f)
f.close()
return obj
def test_pickling_index():
i = toindex([1, 2, 3])
i.tousertensor()
i.todgltensor() # construct a dgl tensor which is unpicklable
i2 = _reconstruct_pickle(i)
assert torch.equal(i2.tousertensor(), i.tousertensor())
def test_pickling_graph_index():
gi = create_graph_index()
gi.add_nodes(3)
src_idx = toindex([0, 0])
dst_idx = toindex([1, 2])
gi.add_edges(src_idx, dst_idx)
gi2 = _reconstruct_pickle(gi)
assert gi2.number_of_nodes() == gi.number_of_nodes()
src_idx2, dst_idx2, _ = gi2.edges()
assert torch.equal(src_idx.tousertensor(), src_idx2.tousertensor())
assert torch.equal(dst_idx.tousertensor(), dst_idx2.tousertensor())
def test_pickling_frame():
x = torch.randn(3, 7)
y = torch.randn(3, 5)
c = Column(x)
c2 = _reconstruct_pickle(c)
assert U.allclose(c.data, c2.data)
fr = Frame({'x': x, 'y': y})
fr2 = _reconstruct_pickle(fr)
assert U.allclose(fr2['x'].data, x)
assert U.allclose(fr2['y'].data, y)
fr = Frame()
def _assert_is_identical(g, g2):
assert g.number_of_nodes() == g2.number_of_nodes()
src, dst = g.all_edges()
src2, dst2 = g2.all_edges()
assert torch.equal(src, src2)
assert torch.equal(dst, dst2)
assert len(g.ndata) == len(g2.ndata)
assert len(g.edata) == len(g2.edata)
for k in g.ndata:
assert U.allclose(g.ndata[k], g2.ndata[k])
for k in g.edata:
assert U.allclose(g.edata[k], g2.edata[k])
def _global_message_func(nodes):
return {'x': nodes.data['x']}
def test_pickling_graph():
# graph structures and frames are pickled
g = dgl.DGLGraph()
g.add_nodes(3)
src = torch.LongTensor([0, 0])
dst = torch.LongTensor([1, 2])
g.add_edges(src, dst)
x = torch.randn(3, 7)
y = torch.randn(3, 5)
a = torch.randn(2, 6)
b = torch.randn(2, 4)
g.ndata['x'] = x
g.ndata['y'] = y
g.edata['a'] = a
g.edata['b'] = b
# registered functions are pickled
g.register_message_func(_global_message_func)
reduce_func = F.sum('x', 'x')
g.register_reduce_func(reduce_func)
# custom attributes should be pickled
g.foo = 2
new_g = _reconstruct_pickle(g)
_assert_is_identical(g, new_g)
assert new_g.foo == 2
assert new_g._message_func == _global_message_func
assert isinstance(new_g._reduce_func, type(reduce_func))
assert new_g._reduce_func._name == 'sum'
assert new_g._reduce_func.op == backend.sum
assert new_g._reduce_func.msg_field == 'x'
assert new_g._reduce_func.out_field == 'x'
# test batched graph with partial set case
g2 = dgl.DGLGraph()
g2.add_nodes(4)
src2 = torch.LongTensor([0, 1])
dst2 = torch.LongTensor([2, 3])
g2.add_edges(src2, dst2)
x2 = torch.randn(4, 7)
y2 = torch.randn(3, 5)
a2 = torch.randn(2, 6)
b2 = torch.randn(2, 4)
g2.ndata['x'] = x2
g2.nodes[[0, 1, 3]].data['y'] = y2
g2.edata['a'] = a2
g2.edata['b'] = b2
bg = dgl.batch([g, g2])
bg2 = _reconstruct_pickle(bg)
_assert_is_identical(bg, bg2)
new_g, new_g2 = dgl.unbatch(bg2)
_assert_is_identical(g, new_g)
_assert_is_identical(g2, new_g2)
if __name__ == '__main__':
test_pickling_index()
test_pickling_graph_index()
test_pickling_frame()
test_pickling_graph()
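Beyond the in-memory round trip exercised above, the same support makes it possible to write a graph to disk with the standard pickle module; a minimal usage sketch (file name hypothetical):

import pickle
import torch
import dgl

g = dgl.DGLGraph()
g.add_nodes(3)
g.add_edges(torch.LongTensor([0, 0]), torch.LongTensor([1, 2]))
g.ndata['x'] = torch.randn(3, 7)

with open('graph.pkl', 'wb') as f:    # hypothetical path
    pickle.dump(g, f)
with open('graph.pkl', 'rb') as f:
    g2 = pickle.load(f)
assert g2.number_of_nodes() == g.number_of_nodes()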